AC-1235 Improvements to inventory import and computed field performance.

Chris Church
2014-05-18 22:10:06 -04:00
parent 0fcfe6114a
commit b14aa0b55d
5 changed files with 717 additions and 265 deletions

View File

@@ -430,7 +430,8 @@ def load_inventory_source(source, all_group=None, group_filter_re=None,
     original_all_group = all_group
     if not os.path.exists(source):
        raise IOError('Source does not exist: %s' % source)
-    source = os.path.join(os.path.dirname(source) or os.getcwd(), source)
+    source = os.path.join(os.getcwd(), os.path.dirname(source),
+                          os.path.basename(source))
     source = os.path.normpath(os.path.abspath(source))
     if os.path.isdir(source):
         all_group = all_group or MemGroup('all', source)
@@ -511,7 +512,7 @@ class Command(NoArgsCommand):
         )

     def init_logging(self):
-        log_levels = dict(enumerate([logging.ERROR, logging.INFO,
+        log_levels = dict(enumerate([logging.WARNING, logging.INFO,
                                      logging.DEBUG, 0]))
         self.logger = logging.getLogger('awx.main.commands.inventory_import')
         self.logger.setLevel(log_levels.get(self.verbosity, 0))
@@ -577,15 +578,24 @@ class Command(NoArgsCommand):
         # FIXME: Wait or raise error if inventory is being updated by another
         # source.

-    def load_into_database(self):
-        '''
-        Load inventory from in-memory groups to the database, overwriting or
-        merging as appropriate.
-        '''
+    def _batch_add_m2m(self, related_manager, *objs, **kwargs):
+        key = (related_manager.instance.pk, related_manager.through._meta.db_table)
+        flush = bool(kwargs.get('flush', False))
+        if not hasattr(self, '_batch_add_m2m_cache'):
+            self._batch_add_m2m_cache = {}
+        cached_objs = self._batch_add_m2m_cache.setdefault(key, [])
+        cached_objs.extend(objs)
+        if len(cached_objs) > 100 or flush:
+            if len(cached_objs):
+                related_manager.add(*cached_objs)
+            self._batch_add_m2m_cache[key] = []

-        # Find any hosts in the database without an instance_id set that may
-        # still have one available via host variables.
-        db_instance_id_map = {}
+    def _build_db_instance_id_map(self):
+        '''
+        Find any hosts in the database without an instance_id set that may
+        still have one available via host variables.
+        '''
+        self.db_instance_id_map = {}
         if self.instance_id_var:
             if self.inventory_source.group:
                 host_qs = self.inventory_source.group.all_hosts
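The new _batch_add_m2m helper above is the core of the import speedup: instead of issuing one INSERT per relationship, it queues objects per (instance, through-table) key and lets Django's related_manager.add() write up to 100 rows in a single bulk query. A minimal standalone sketch of the same pattern (the `manager` argument is a hypothetical stand-in for any Django related manager; the batch size and flush semantics mirror the helper above):

    class BatchAdder(object):
        '''Queue objects for a many-to-many add and flush them in bulk.'''

        def __init__(self, manager, batch_size=100):
            self.manager = manager          # e.g. some_group.hosts
            self.batch_size = batch_size
            self.pending = []

        def add(self, *objs):
            self.pending.extend(objs)
            if len(self.pending) >= self.batch_size:
                self.flush()

        def flush(self):
            if self.pending:
                # One bulk INSERT for the whole batch instead of N queries.
                self.manager.add(*self.pending)
                self.pending = []

Callers are responsible for a final flush once their loop ends, which is why the command code below passes flush=True after each create/update pass.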
@@ -597,11 +607,14 @@ class Command(NoArgsCommand):
instance_id = host.variables_dict.get(self.instance_id_var, '') instance_id = host.variables_dict.get(self.instance_id_var, '')
if not instance_id: if not instance_id:
continue continue
db_instance_id_map[instance_id] = host.pk self.db_instance_id_map[instance_id] = host.pk
# Update instance ID for each imported host and define a mapping of def _build_mem_instance_id_map(self):
# instance IDs to MemHost instances. '''
mem_instance_id_map = {} Update instance ID for each imported host and define a mapping of
instance IDs to MemHost instances.
'''
self.mem_instance_id_map = {}
if self.instance_id_var: if self.instance_id_var:
for mem_host in self.all_group.all_hosts.values(): for mem_host in self.all_group.all_hosts.values():
instance_id = mem_host.variables.get(self.instance_id_var, '') instance_id = mem_host.variables.get(self.instance_id_var, '')
@@ -610,80 +623,113 @@ class Command(NoArgsCommand):
                                         mem_host.name, self.instance_id_var)
                     continue
                 mem_host.instance_id = instance_id
-                mem_instance_id_map[instance_id] = mem_host.name
-        #self.logger.warning('%r', instance_id_map)
+                self.mem_instance_id_map[instance_id] = mem_host.name

-        # If overwrite is set, for each host in the database that is NOT in
-        # the local list, delete it. When importing from a cloud inventory
-        # source attached to a specific group, only delete hosts beneath that
-        # group. Delete each host individually so signal handlers will run.
-        if self.overwrite:
-            if self.inventory_source.group:
-                del_hosts = self.inventory_source.group.all_hosts
-                # FIXME: Also include hosts from inventory_source.managed_hosts?
-            else:
-                del_hosts = self.inventory.hosts.filter(active=True)
-            instance_ids = set(mem_instance_id_map.keys())
-            host_pks = set([v for k,v in db_instance_id_map.items() if k in instance_ids])
-            host_names = set(mem_instance_id_map.values()) - set(self.all_group.all_hosts.keys())
-            del_hosts = del_hosts.exclude(Q(name__in=host_names) | Q(instance_id__in=instance_ids) | Q(pk__in=host_pks))
-            for host in del_hosts:
-                host_name = host.name
-                host.mark_inactive()
-                self.logger.info('Deleted host "%s"', host_name)
+    def _delete_hosts(self):
+        '''
+        For each host in the database that is NOT in the local list, delete
+        it. When importing from a cloud inventory source attached to a
+        specific group, only delete hosts beneath that group. Delete each
+        host individually so signal handlers will run.
+        '''
+        if settings.SQL_DEBUG:
+            queries_before = len(connection.queries)
+        if self.inventory_source.group:
+            del_hosts = self.inventory_source.group.all_hosts
+            # FIXME: Also include hosts from inventory_source.managed_hosts?
+        else:
+            del_hosts = self.inventory.hosts.filter(active=True)
+        if self.instance_id_var:
+            instance_ids = set(self.mem_instance_id_map.keys())
+            host_pks = set([v for k,v in self.db_instance_id_map.items() if k in instance_ids])
+            host_names = set(self.mem_instance_id_map.values()) - set(self.all_group.all_hosts.keys())
+            del_hosts = del_hosts.exclude(Q(name__in=host_names) | Q(instance_id__in=instance_ids) | Q(pk__in=host_pks))
+        else:
+            del_hosts = del_hosts.exclude(name__in=self.all_group.all_hosts.keys())
+        for host in del_hosts:
+            host_name = host.name
+            host.mark_inactive()#from_inventory_import=True)
+            self.logger.info('Deleted host "%s"', host_name)
+        if settings.SQL_DEBUG:
+            self.logger.warning('host deletions took %d queries for %d hosts',
+                                len(connection.queries) - queries_before,
+                                del_hosts.count())

-        # If overwrite is set, for each group in the database that is NOT in
-        # the local list, delete it. When importing from a cloud inventory
-        # source attached to a specific group, only delete children of that
-        # group. Delete each group individually so signal handlers will run.
-        if self.overwrite:
-            if self.inventory_source.group:
-                del_groups = self.inventory_source.group.all_children
-                # FIXME: Also include groups from inventory_source.managed_groups?
-            else:
-                del_groups = self.inventory.groups.filter(active=True)
-            group_names = set(self.all_group.all_groups.keys())
-            del_groups = del_groups.exclude(name__in=group_names)
-            for group in del_groups:
-                group_name = group.name
-                group.mark_inactive(recompute=False)
-                self.logger.info('Group "%s" deleted', group_name)
+    def _delete_groups(self):
+        '''
+        For each group in the database that is NOT in the local list, delete
+        it. When importing from a cloud inventory source attached to a
+        specific group, only delete children of that group. Delete each
+        group individually so signal handlers will run.
+        '''
+        if settings.SQL_DEBUG:
+            queries_before = len(connection.queries)
+        if self.inventory_source.group:
+            del_groups = self.inventory_source.group.all_children
+            # FIXME: Also include groups from inventory_source.managed_groups?
+        else:
+            del_groups = self.inventory.groups.filter(active=True)
+        group_names = set(self.all_group.all_groups.keys())
+        del_groups = del_groups.exclude(name__in=group_names)
+        for group in del_groups:
+            group_name = group.name
+            group.mark_inactive(recompute=False)#from_inventory_import=True)
+            self.logger.info('Group "%s" deleted', group_name)
+        if settings.SQL_DEBUG:
+            self.logger.warning('group deletions took %d queries for %d groups',
+                                len(connection.queries) - queries_before,
+                                del_groups.count())

-        # If overwrite is set, clear all invalid child relationships for groups
-        # and all invalid host memberships. When importing from a cloud
-        # inventory source attached to a specific group, only clear
-        # relationships for hosts and groups that are beneath the inventory
-        # source group.
-        if self.overwrite:
-            if self.inventory_source.group:
-                db_groups = self.inventory_source.group.all_children
-            else:
-                db_groups = self.inventory.groups.filter(active=True)
-            for db_group in db_groups:
-                db_children = db_group.children.filter(active=True)
-                mem_children = self.all_group.all_groups[db_group.name].children
-                mem_children_names = [g.name for g in mem_children]
-                for db_child in db_children.exclude(name__in=mem_children_names):
-                    if db_child not in db_group.children.filter(active=True):
-                        continue
-                    db_group.children.remove(db_child)
-                    self.logger.info('Group "%s" removed from group "%s"',
-                                     db_child.name, db_group.name)
-                db_hosts = db_group.hosts.filter(active=True)
-                mem_hosts = self.all_group.all_groups[db_group.name].hosts
-                mem_host_names = set([h.name for h in mem_hosts if not h.instance_id])
-                mem_instance_ids = set([h.instance_id for h in mem_hosts if h.instance_id])
-                db_host_pks = set([v for k,v in db_instance_id_map.items() if k in mem_instance_ids])
-                for db_host in db_hosts.exclude(Q(name__in=mem_host_names) | Q(instance_id__in=mem_instance_ids) | Q(pk__in=db_host_pks)):
-                    if db_host not in db_group.hosts.filter(active=True):
-                        continue
-                    db_group.hosts.remove(db_host)
-                    self.logger.info('Host "%s" removed from group "%s"',
-                                     db_host.name, db_group.name)
+    def _delete_group_children_and_hosts(self):
+        '''
+        Clear all invalid child relationships for groups and all invalid host
+        memberships. When importing from a cloud inventory source attached to
+        a specific group, only clear relationships for hosts and groups that
+        are beneath the inventory source group.
+        '''
+        # FIXME: Optimize performance!
+        if settings.SQL_DEBUG:
+            queries_before = len(connection.queries)
+        group_group_count = 0
+        group_host_count = 0
+        if self.inventory_source.group:
+            db_groups = self.inventory_source.group.all_children
+        else:
+            db_groups = self.inventory.groups.filter(active=True)
+        for db_group in db_groups:
+            db_children = db_group.children.filter(active=True)
+            mem_children = self.all_group.all_groups[db_group.name].children
+            mem_children_names = [g.name for g in mem_children]
+            for db_child in db_children.exclude(name__in=mem_children_names):
+                group_group_count += 1
+                if db_child not in db_group.children.filter(active=True):
+                    continue
+                db_group.children.remove(db_child)
+                self.logger.info('Group "%s" removed from group "%s"',
+                                 db_child.name, db_group.name)
+            db_hosts = db_group.hosts.filter(active=True)
+            mem_hosts = self.all_group.all_groups[db_group.name].hosts
+            mem_host_names = set([h.name for h in mem_hosts if not h.instance_id])
+            mem_instance_ids = set([h.instance_id for h in mem_hosts if h.instance_id])
+            db_host_pks = set([v for k,v in self.db_instance_id_map.items() if k in mem_instance_ids])
+            for db_host in db_hosts.exclude(Q(name__in=mem_host_names) | Q(instance_id__in=mem_instance_ids) | Q(pk__in=db_host_pks)):
+                group_host_count += 1
+                if db_host not in db_group.hosts.filter(active=True):
+                    continue
+                db_group.hosts.remove(db_host)
+                self.logger.info('Host "%s" removed from group "%s"',
+                                 db_host.name, db_group.name)
+        if settings.SQL_DEBUG:
+            self.logger.warning('group-group and group-host deletions took %d queries for %d relationships',
+                                len(connection.queries) - queries_before,
+                                group_group_count + group_host_count)

-        # Update/overwrite variables from "all" group. If importing from a
-        # cloud source attached to a specific group, variables will be set on
-        # the base group, otherwise they will be set on the whole inventory.
+    def _update_inventory(self):
+        '''
+        Update/overwrite variables from "all" group. If importing from a
+        cloud source attached to a specific group, variables will be set on
+        the base group, otherwise they will be set on the whole inventory.
+        '''
         if self.inventory_source.group:
             all_obj = self.inventory_source.group
             all_obj.inventory_sources.add(self.inventory_source)
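The settings.SQL_DEBUG blocks above bracket each phase with len(connection.queries), which works because Django records every executed statement on connection.queries while query logging is enabled (normally tied to settings.DEBUG; SQL_DEBUG is assumed here to switch the same logging on). A small sketch of the same measurement, usable around any block of ORM calls:

    from django.db import connection

    def run_counting_queries(func, *args, **kwargs):
        '''Run func and report how many SQL queries it issued.
        Assumes query logging is enabled (settings.DEBUG or similar),
        otherwise connection.queries stays empty.'''
        queries_before = len(connection.queries)
        result = func(*args, **kwargs)
        print 'ran %d queries' % (len(connection.queries) - queries_before)
        return result

Since connection.queries grows without bound while logging is on, long-running imports measured this way also pay a memory cost, which is presumably acceptable for the targeted profiling done in this commit.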
@@ -706,151 +752,262 @@ class Command(NoArgsCommand):
         else:
             self.logger.info('%s variables unmodified', all_name.capitalize())
-        # FIXME: Attribute changes to superuser?

-        # For each group in the local list, create it if it doesn't exist in
-        # the database. Otherwise, update/replace database variables from the
-        # imported data. Associate with the inventory source group if
-        # importing from cloud inventory source.
-        for k,v in self.all_group.all_groups.iteritems():
-            variables = json.dumps(v.variables)
-            defaults = dict(variables=variables, description='imported')
-            group, created = self.inventory.groups.get_or_create(name=k,
-                                                                 defaults=defaults)
-            # Access auto one-to-one attribute to create related object.
-            group.inventory_source
-            if created:
-                self.logger.info('Group "%s" added', k)
-            else:
-                db_variables = group.variables_dict
-                if self.overwrite_vars or self.overwrite:
-                    db_variables = v.variables
-                else:
-                    db_variables.update(v.variables)
-                if db_variables != group.variables_dict:
-                    group.variables = json.dumps(db_variables)
-                    group.save(update_fields=['variables'])
-                    if self.overwrite_vars or self.overwrite:
-                        self.logger.info('Group "%s" variables replaced', k)
-                    else:
-                        self.logger.info('Group "%s" variables updated', k)
-                else:
-                    self.logger.info('Group "%s" variables unmodified', k)
-            if self.inventory_source.group and self.inventory_source.group != group:
-                self.inventory_source.group.children.add(group)
-            group.inventory_sources.add(self.inventory_source)
+    def _create_update_groups(self):
+        '''
+        For each group in the local list, create it if it doesn't exist in the
+        database. Otherwise, update/replace database variables from the
+        imported data. Associate with the inventory source group if importing
+        from cloud inventory source.
+        '''
+        if settings.SQL_DEBUG:
+            queries_before = len(connection.queries)
+        inv_src_group = self.inventory_source.group
+        group_names = set(self.all_group.all_groups.keys())
+        for group in self.inventory.groups.filter(name__in=group_names):
+            mem_group = self.all_group.all_groups[group.name]
+            db_variables = group.variables_dict
+            if self.overwrite_vars or self.overwrite:
+                db_variables = mem_group.variables
+            else:
+                db_variables.update(mem_group.variables)
+            if db_variables != group.variables_dict:
+                group.variables = json.dumps(db_variables)
+                group.save(update_fields=['variables'])
+                if self.overwrite_vars or self.overwrite:
+                    self.logger.info('Group "%s" variables replaced', group.name)
+                else:
+                    self.logger.info('Group "%s" variables updated', group.name)
+            else:
+                self.logger.info('Group "%s" variables unmodified', group.name)
+            group_names.remove(group.name)
+            if inv_src_group and inv_src_group != group:
+                self._batch_add_m2m(inv_src_group.children, group)
+            self._batch_add_m2m(self.inventory_source.groups, group)
+        for group_name in group_names:
+            mem_group = self.all_group.all_groups[group_name]
+            group = self.inventory.groups.create(name=group_name, variables=json.dumps(mem_group.variables), description='imported')
+            # Access auto one-to-one attribute to create related object.
+            #group.inventory_source
+            InventorySource.objects.create(group=group, inventory=self.inventory, name=('%s (%s)' % (group_name, self.inventory.name)))
+            self.logger.info('Group "%s" added', group.name)
+            if inv_src_group:
+                self._batch_add_m2m(inv_src_group.children, group)
+            self._batch_add_m2m(self.inventory_source.groups, group)
+        if inv_src_group:
+            self._batch_add_m2m(inv_src_group.children, flush=True)
+        self._batch_add_m2m(self.inventory_source.groups, flush=True)
+        if settings.SQL_DEBUG:
+            self.logger.warning('group updates took %d queries for %d groups',
+                                len(connection.queries) - queries_before,
+                                len(self.all_group.all_groups))

-        # For each host in the local list, create it if it doesn't exist in
-        # the database. Otherwise, update/replace database variables from the
-        # imported data. Associate with the inventory source group if
-        # importing from cloud inventory source.
-        for k,v in self.all_group.all_hosts.iteritems():
-            variables = json.dumps(v.variables)
-            defaults = dict(variables=variables, name=k, description='imported')
-            enabled = None
-            if self.enabled_var and self.enabled_var in v.variables:
-                value = v.variables[self.enabled_var]
-                if self.enabled_value is not None:
-                    enabled = bool(unicode(self.enabled_value) == unicode(value))
-                else:
-                    enabled = bool(value)
-                defaults['enabled'] = enabled
-            instance_id = ''
-            if self.instance_id_var:
-                instance_id = v.variables.get(self.instance_id_var, '')
-                defaults['instance_id'] = instance_id
-            if instance_id in db_instance_id_map:
-                attrs = {'pk': db_instance_id_map[instance_id]}
-            elif instance_id:
-                attrs = {'instance_id': instance_id}
-                defaults.pop('instance_id')
-            else:
-                attrs = {'name': k}
-                defaults.pop('name')
-            attrs['defaults'] = defaults
-            host, created = self.inventory.hosts.get_or_create(**attrs)
-            if created:
-                if enabled is False:
-                    self.logger.info('Host "%s" added (disabled)', k)
-                else:
-                    self.logger.info('Host "%s" added', k)
-                #self.logger.info('Host variables: %s', variables)
-            else:
-                db_variables = host.variables_dict
-                if self.overwrite_vars or self.overwrite:
-                    db_variables = v.variables
-                else:
-                    db_variables.update(v.variables)
-                update_fields = []
-                if db_variables != host.variables_dict:
-                    host.variables = json.dumps(db_variables)
-                    update_fields.append('variables')
-                if enabled is not None and host.enabled != enabled:
-                    host.enabled = enabled
-                    update_fields.append('enabled')
-                if k != host.name:
-                    old_name = host.name
-                    host.name = k
-                    update_fields.append('name')
-                if instance_id != host.instance_id:
-                    old_instance_id = host.instance_id
-                    host.instance_id = instance_id
-                    update_fields.append('instance_id')
-                if update_fields:
-                    host.save(update_fields=update_fields)
-                if 'name' in update_fields:
-                    self.logger.info('Host renamed from "%s" to "%s"', old_name, k)
-                if 'instance_id' in update_fields:
-                    if old_instance_id:
-                        self.logger.info('Host "%s" instance_id updated', k)
-                    else:
-                        self.logger.info('Host "%s" instance_id added', k)
-                if 'variables' in update_fields:
-                    if self.overwrite_vars or self.overwrite:
-                        self.logger.info('Host "%s" variables replaced', k)
-                    else:
-                        self.logger.info('Host "%s" variables updated', k)
-                else:
-                    self.logger.info('Host "%s" variables unmodified', k)
-                if 'enabled' in update_fields:
-                    if enabled:
-                        self.logger.info('Host "%s" is now enabled', k)
-                    else:
-                        self.logger.info('Host "%s" is now disabled', k)
-            if self.inventory_source.group:
-                self.inventory_source.group.hosts.add(host)
-            host.inventory_sources.add(self.inventory_source)
-            host.update_computed_fields(False, False)
+    def _update_db_host_from_mem_host(self, db_host, mem_host):
+        # Update host variables.
+        db_variables = db_host.variables_dict
+        if self.overwrite_vars or self.overwrite:
+            db_variables = mem_host.variables
+        else:
+            db_variables.update(mem_host.variables)
+        update_fields = []
+        if db_variables != db_host.variables_dict:
+            db_host.variables = json.dumps(db_variables)
+            update_fields.append('variables')
+        # Update host enabled flag.
+        enabled = None
+        if self.enabled_var and self.enabled_var in mem_host.variables:
+            value = mem_host.variables[self.enabled_var]
+            if self.enabled_value is not None:
+                enabled = bool(unicode(self.enabled_value) == unicode(value))
+            else:
+                enabled = bool(value)
+        if enabled is not None and db_host.enabled != enabled:
+            db_host.enabled = enabled
+            update_fields.append('enabled')
+        # Update host name.
+        if mem_host.name != db_host.name:
+            old_name = db_host.name
+            db_host.name = mem_host.name
+            update_fields.append('name')
+        # Update host instance_id.
+        if self.instance_id_var:
+            instance_id = mem_host.variables.get(self.instance_id_var, '')
+        else:
+            instance_id = ''
+        if instance_id != db_host.instance_id:
+            old_instance_id = db_host.instance_id
+            db_host.instance_id = instance_id
+            update_fields.append('instance_id')
+        # Update host and display message(s) on what changed.
+        if update_fields:
+            db_host.save(update_fields=update_fields)
+        if 'name' in update_fields:
+            self.logger.info('Host renamed from "%s" to "%s"', old_name, mem_host.name)
+        if 'instance_id' in update_fields:
+            if old_instance_id:
+                self.logger.info('Host "%s" instance_id updated', mem_host.name)
+            else:
+                self.logger.info('Host "%s" instance_id added', mem_host.name)
+        if 'variables' in update_fields:
+            if self.overwrite_vars or self.overwrite:
+                self.logger.info('Host "%s" variables replaced', mem_host.name)
+            else:
+                self.logger.info('Host "%s" variables updated', mem_host.name)
+        else:
+            self.logger.info('Host "%s" variables unmodified', mem_host.name)
+        if 'enabled' in update_fields:
+            if enabled:
+                self.logger.info('Host "%s" is now enabled', mem_host.name)
+            else:
+                self.logger.info('Host "%s" is now disabled', mem_host.name)
+        if self.inventory_source.group:
+            self._batch_add_m2m(self.inventory_source.group.hosts, db_host)
+        self._batch_add_m2m(self.inventory_source.hosts, db_host)
+        #host.update_computed_fields(False, False)
+
+    def _create_update_hosts(self):
+        '''
+        For each host in the local list, create it if it doesn't exist in the
+        database. Otherwise, update/replace database variables from the
+        imported data. Associate with the inventory source group if importing
+        from cloud inventory source.
+        '''
+        if settings.SQL_DEBUG:
+            queries_before = len(connection.queries)
+        host_pks_updated = set()
+        mem_host_pk_map = {}
+        mem_host_instance_id_map = {}
+        mem_host_name_map = {}
+        mem_host_names_to_update = set(self.all_group.all_hosts.keys())
+        for k,v in self.all_group.all_hosts.iteritems():
+            instance_id = ''
+            if self.instance_id_var:
+                instance_id = v.variables.get(self.instance_id_var, '')
+            if instance_id in self.db_instance_id_map:
+                mem_host_pk_map[self.db_instance_id_map[instance_id]] = v
+            elif instance_id:
+                mem_host_instance_id_map[instance_id] = v
+            else:
+                mem_host_name_map[k] = v
+        # Update all existing hosts where we know the PK based on instance_id.
+        for db_host in self.inventory.hosts.filter(active=True, pk__in=mem_host_pk_map.keys()):
+            mem_host = mem_host_pk_map[db_host.pk]
+            self._update_db_host_from_mem_host(db_host, mem_host)
+            host_pks_updated.add(db_host.pk)
+            mem_host_names_to_update.discard(mem_host.name)
+        # Update all existing hosts where we know the instance_id.
+        for db_host in self.inventory.hosts.filter(active=True, instance_id__in=mem_host_instance_id_map.keys()).exclude(pk__in=host_pks_updated):
+            mem_host = mem_host_instance_id_map[db_host.instance_id]
+            self._update_db_host_from_mem_host(db_host, mem_host)
+            host_pks_updated.add(db_host.pk)
+            mem_host_names_to_update.discard(mem_host.name)
+        # Update all existing hosts by name.
+        for db_host in self.inventory.hosts.filter(active=True, name__in=mem_host_name_map.keys()).exclude(pk__in=host_pks_updated):
+            mem_host = mem_host_name_map[db_host.name]
+            self._update_db_host_from_mem_host(db_host, mem_host)
+            host_pks_updated.add(db_host.pk)
+            mem_host_names_to_update.discard(mem_host.name)
+        # Create any new hosts.
+        for mem_host_name in mem_host_names_to_update:
+            mem_host = self.all_group.all_hosts[mem_host_name]
+            host_attrs = dict(variables=json.dumps(mem_host.variables),
+                              name=mem_host_name, description='imported')
+            enabled = None
+            if self.enabled_var and self.enabled_var in mem_host.variables:
+                value = mem_host.variables[self.enabled_var]
+                if self.enabled_value is not None:
+                    enabled = bool(unicode(self.enabled_value) == unicode(value))
+                else:
+                    enabled = bool(value)
+                host_attrs['enabled'] = enabled
+            if self.instance_id_var:
+                instance_id = mem_host.variables.get(self.instance_id_var, '')
+                host_attrs['instance_id'] = instance_id
+            db_host = self.inventory.hosts.create(**host_attrs)
+            if enabled is False:
+                self.logger.info('Host "%s" added (disabled)', mem_host_name)
+            else:
+                self.logger.info('Host "%s" added', mem_host_name)
+            if self.inventory_source.group:
+                self._batch_add_m2m(self.inventory_source.group.hosts, db_host)
+            self._batch_add_m2m(self.inventory_source.hosts, db_host)
+            #host.update_computed_fields(False, False)
+        if self.inventory_source.group:
+            self._batch_add_m2m(self.inventory_source.group.hosts, flush=True)
+        self._batch_add_m2m(self.inventory_source.hosts, flush=True)
+        if settings.SQL_DEBUG:
+            self.logger.warning('host updates took %d queries for %d hosts',
+                                len(connection.queries) - queries_before,
+                                len(self.all_group.all_hosts))

+    def _create_update_group_children(self):
+        '''
+        For each imported group, create all parent-child group relationships.
+        '''
+        if settings.SQL_DEBUG:
+            queries_before = len(connection.queries)
+        group_names = [k for k,v in self.all_group.all_groups.iteritems() if v.children]
+        group_group_count = 0
+        for db_group in self.inventory.groups.filter(name__in=group_names):
+            mem_group = self.all_group.all_groups[db_group.name]
+            group_group_count += len(mem_group.children)
+            child_names = set([g.name for g in mem_group.children])
+            db_children_qs = self.inventory.groups.filter(name__in=child_names)
+            for db_child in db_children_qs.filter(children__id=db_group.id):
+                self.logger.info('Group "%s" already child of group "%s"', db_child.name, db_group.name)
+            for db_child in db_children_qs.exclude(children__id=db_group.id):
+                self._batch_add_m2m(db_group.children, db_child)
+                self.logger.info('Group "%s" added as child of "%s"', db_child.name, db_group.name)
+            self._batch_add_m2m(db_group.children, flush=True)
+        if settings.SQL_DEBUG:
+            self.logger.warning('Group-group updates took %d queries for %d group-group relationships',
+                                len(connection.queries) - queries_before, group_group_count)
+
+    def _create_update_group_hosts(self):
         # For each host in a mem group, add it to the parent(s) to which it
         # belongs.
-        for k,v in self.all_group.all_groups.iteritems():
-            if not v.hosts:
-                continue
-            db_group = self.inventory.groups.get(name=k)
-            for h in v.hosts:
-                if h.instance_id:
-                    db_host = self.inventory.hosts.get(instance_id=h.instance_id)
-                else:
-                    db_host = self.inventory.hosts.get(name=h.name)
-                if db_host not in db_group.hosts.all():
-                    db_group.hosts.add(db_host)
-                    self.logger.info('Host "%s" added to group "%s"', h.name, k)
-                else:
-                    self.logger.info('Host "%s" already in group "%s"', h.name, k)
+        if settings.SQL_DEBUG:
+            queries_before = len(connection.queries)
+        group_names = [k for k,v in self.all_group.all_groups.iteritems() if v.hosts]
+        group_host_count = 0
+        for db_group in self.inventory.groups.filter(name__in=group_names):
+            mem_group = self.all_group.all_groups[db_group.name]
+            group_host_count += len(mem_group.hosts)
+            host_names = set([h.name for h in mem_group.hosts if not h.instance_id])
+            host_instance_ids = set([h.instance_id for h in mem_group.hosts if h.instance_id])
+            db_hosts_qs = self.inventory.hosts.filter(Q(name__in=host_names) | Q(instance_id__in=host_instance_ids))
+            for db_host in db_hosts_qs.filter(groups__id=db_group.id):
+                self.logger.info('Host "%s" already in group "%s"', db_host.name, db_group.name)
+            for db_host in db_hosts_qs.exclude(groups__id=db_group.id):
+                self._batch_add_m2m(db_group.hosts, db_host)
+                self.logger.info('Host "%s" added to group "%s"', db_host.name, db_group.name)
+            self._batch_add_m2m(db_group.hosts, flush=True)
+        if settings.SQL_DEBUG:
+            self.logger.warning('Group-host updates took %d queries for %d group-host relationships',
+                                len(connection.queries) - queries_before, group_host_count)

-        # for each group, draw in child group arrangements
-        for k,v in self.all_group.all_groups.iteritems():
-            if not v.children:
-                continue
-            db_group = self.inventory.groups.get(name=k)
-            for g in v.children:
-                db_child = self.inventory.groups.get(name=g.name)
-                if db_child not in db_group.hosts.all():
-                    db_group.children.add(db_child)
-                    self.logger.info('Group "%s" added as child of "%s"', g.name, k)
-                else:
-                    self.logger.info('Group "%s" already child of group "%s"', g.name, k)
+    def load_into_database(self):
+        '''
+        Load inventory from in-memory groups to the database, overwriting or
+        merging as appropriate.
+        '''
+        # FIXME: Attribute changes to superuser?
+        self._build_db_instance_id_map()
+        self._build_mem_instance_id_map()
+        if self.overwrite:
+            self._delete_hosts()
+            self._delete_groups()
+            self._delete_group_children_and_hosts()
+        self._update_inventory()
+        self._create_update_groups()
+        self._create_update_hosts()
+        self._create_update_group_children()
+        self._create_update_group_hosts()

     def check_license(self):
         reader = LicenseReader()
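A note on the host-matching order in _create_update_hosts above: imported hosts are bucketed up front, then matched against the database in three passes of decreasing reliability (pk when the instance_id was already mapped to an existing row, then instance_id, then name), with each pass excluding rows already updated; whatever remains is genuinely new and gets created. A condensed sketch of that precedence (the dict names are shorthand for the mem_host_*_map variables above, and `update` stands in for _update_db_host_from_mem_host):

    updated_pks = set()
    names_left = set(all_imported_names)
    passes = [
        ('pk', by_pk),                  # instance_id resolved to an existing pk
        ('instance_id', by_instance_id),
        ('name', by_name),
    ]
    for field, bucket in passes:
        lookup = {'%s__in' % field: bucket.keys()}
        for db_host in db_hosts.filter(active=True, **lookup).exclude(pk__in=updated_pks):
            mem_host = bucket[getattr(db_host, field)]
            update(db_host, mem_host)
            updated_pks.add(db_host.pk)
            names_left.discard(mem_host.name)
    # names_left now holds only hosts that must be created.

Because each pass filters on a whole set of keys, matching costs a handful of queries in total instead of one get_or_create() per host as in the removed code.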
@@ -914,6 +1071,9 @@ class Command(NoArgsCommand):
         status, tb, exc = 'error', '', None
         try:
+            if settings.SQL_DEBUG:
+                queries_before = len(connection.queries)
+
             # Update inventory update for this command line invocation.
             with ignore_inventory_computed_fields():
                 if self.inventory_update:
@@ -935,7 +1095,12 @@ class Command(NoArgsCommand):
                 else:
                     with disable_activity_stream():
                         self.load_into_database()
+                if settings.SQL_DEBUG:
+                    queries_before2 = len(connection.queries)
                 self.inventory.update_computed_fields()
+                if settings.SQL_DEBUG:
+                    self.logger.warning('update computed fields took %d queries',
+                                        len(connection.queries) - queries_before2)
             self.check_license()
             if self.inventory_source.group:
@@ -943,13 +1108,18 @@ class Command(NoArgsCommand):
             else:
                 inv_name = '"%s" (id=%s)' % (self.inventory.name,
                                              self.inventory.id)
-            self.logger.info('Inventory import completed for %s in %0.1fs',
-                             inv_name, time.time() - begin)
+            if settings.SQL_DEBUG:
+                self.logger.warning('Inventory import completed for %s in %0.1fs',
+                                    inv_name, time.time() - begin)
+            else:
+                self.logger.info('Inventory import completed for %s in %0.1fs',
+                                 inv_name, time.time() - begin)
             status = 'successful'
-            if settings.DEBUG:
-                sqltime = sum(float(x['time']) for x in connection.queries)
-                self.logger.info('Inventory import required %d queries '
-                                 'taking %0.3fs', len(connection.queries),
+            if settings.SQL_DEBUG:
+                queries_this_import = connection.queries[queries_before:]
+                sqltime = sum(float(x['time']) for x in queries_this_import)
+                self.logger.warning('Inventory import required %d queries '
+                                    'taking %0.3fs', len(queries_this_import),
                                     sqltime)
         except Exception, e:
             if isinstance(e, KeyboardInterrupt):

View File

@@ -21,7 +21,7 @@ import zmq
 # Django
 from django.conf import settings
 from django.db import models
-from django.db.models import CASCADE, SET_NULL, PROTECT
+from django.db.models import Q
 from django.utils.translation import ugettext_lazy as _
 from django.core.exceptions import ValidationError, NON_FIELD_ERRORS
 from django.core.urlresolvers import reverse
@@ -40,6 +40,7 @@ __all__ = ['Inventory', 'Host', 'Group', 'InventorySource', 'InventoryUpdate']

 logger = logging.getLogger('awx.main.models.inventory')

+
 class Inventory(CommonModel):
     '''
     an inventory source contains lists and hosts.
@@ -120,24 +121,174 @@ class Inventory(CommonModel):

     variables_dict = VarsDictProperty('variables')

+    def get_group_hosts_map(self, active=None):
+        '''
+        Return dictionary mapping group_id to set of child host_id's.
+        '''
+        # FIXME: Cache this mapping?
+        group_hosts_kw = dict(group__inventory_id=self.pk, host__inventory_id=self.pk)
+        if active is not None:
+            group_hosts_kw['group__active'] = active
+            group_hosts_kw['host__active'] = active
+        group_hosts_qs = Group.hosts.through.objects.filter(**group_hosts_kw)
+        group_hosts_qs = group_hosts_qs.values_list('group_id', 'host_id')
+        group_hosts_map = {}
+        for group_id, host_id in group_hosts_qs:
+            group_host_ids = group_hosts_map.setdefault(group_id, set())
+            group_host_ids.add(host_id)
+        return group_hosts_map
+
+    def get_group_parents_map(self, active=None):
+        '''
+        Return dictionary mapping group_id to set of parent group_id's.
+        '''
+        # FIXME: Cache this mapping?
+        group_parents_kw = dict(from_group__inventory_id=self.pk, to_group__inventory_id=self.pk)
+        if active is not None:
+            group_parents_kw['from_group__active'] = active
+            group_parents_kw['to_group__active'] = active
+        group_parents_qs = Group.parents.through.objects.filter(**group_parents_kw)
+        group_parents_qs = group_parents_qs.values_list('from_group_id', 'to_group_id')
+        group_parents_map = {}
+        for from_group_id, to_group_id in group_parents_qs:
+            group_parents = group_parents_map.setdefault(from_group_id, set())
+            group_parents.add(to_group_id)
+        return group_parents_map
+
+    def get_group_children_map(self, active=None):
+        '''
+        Return dictionary mapping group_id to set of child group_id's.
+        '''
+        # FIXME: Cache this mapping?
+        group_parents_kw = dict(from_group__inventory_id=self.pk, to_group__inventory_id=self.pk)
+        if active is not None:
+            group_parents_kw['from_group__active'] = active
+            group_parents_kw['to_group__active'] = active
+        group_parents_qs = Group.parents.through.objects.filter(**group_parents_kw)
+        group_parents_qs = group_parents_qs.values_list('from_group_id', 'to_group_id')
+        group_children_map = {}
+        for from_group_id, to_group_id in group_parents_qs:
+            group_children = group_children_map.setdefault(to_group_id, set())
+            group_children.add(from_group_id)
+        return group_children_map
+
+    def update_host_computed_fields(self):
+        '''
+        Update computed fields for all active hosts in this inventory.
+        '''
+        hosts_to_update = {}
+        hosts_qs = self.hosts.filter(active=True)
+        # Define queryset of all hosts with active failures.
+        hosts_with_active_failures = hosts_qs.filter(last_job_host_summary__isnull=False, last_job_host_summary__job__active=True, last_job_host_summary__failed=True).values_list('pk', flat=True)
+        # Find all hosts that need the has_active_failures flag set.
+        hosts_to_set = hosts_qs.filter(has_active_failures=False, pk__in=hosts_with_active_failures)
+        for host_pk in hosts_to_set.values_list('pk', flat=True):
+            host_updates = hosts_to_update.setdefault(host_pk, {})
+            host_updates['has_active_failures'] = True
+        # Find all hosts that need the has_active_failures flag cleared.
+        hosts_to_clear = hosts_qs.filter(has_active_failures=True).exclude(pk__in=hosts_with_active_failures)
+        for host_pk in hosts_to_clear.values_list('pk', flat=True):
+            host_updates = hosts_to_update.setdefault(host_pk, {})
+            host_updates['has_active_failures'] = False
+        # Define queryset of all hosts with cloud inventory sources.
+        hosts_with_cloud_inventory = hosts_qs.filter(inventory_sources__active=True, inventory_sources__source__in=CLOUD_INVENTORY_SOURCES).values_list('pk', flat=True)
+        # Find all hosts that need the has_inventory_sources flag set.
+        hosts_to_set = hosts_qs.filter(has_inventory_sources=False, pk__in=hosts_with_cloud_inventory)
+        for host_pk in hosts_to_set.values_list('pk', flat=True):
+            host_updates = hosts_to_update.setdefault(host_pk, {})
+            host_updates['has_inventory_sources'] = True
+        # Find all hosts that need the has_inventory_sources flag cleared.
+        hosts_to_clear = hosts_qs.filter(has_inventory_sources=True).exclude(pk__in=hosts_with_cloud_inventory)
+        for host_pk in hosts_to_clear.values_list('pk', flat=True):
+            host_updates = hosts_to_update.setdefault(host_pk, {})
+            host_updates['has_inventory_sources'] = False
+        # Now apply updates to hosts where needed.
+        for host in hosts_qs.filter(pk__in=hosts_to_update.keys()):
+            host_updates = hosts_to_update[host.pk]
+            for field, value in host_updates.items():
+                setattr(host, field, value)
+            host.save(update_fields=host_updates.keys())
+
+    def update_group_computed_fields(self):
+        '''
+        Update computed fields for all active groups in this inventory.
+        '''
+        group_children_map = self.get_group_children_map(active=True)
+        group_hosts_map = self.get_group_hosts_map(active=True)
+        active_host_pks = set(self.hosts.filter(active=True).values_list('pk', flat=True))
+        failed_host_pks = set(self.hosts.filter(active=True, last_job_host_summary__job__active=True, last_job_host_summary__failed=True).values_list('pk', flat=True))
+        active_group_pks = set(self.groups.filter(active=True).values_list('pk', flat=True))
+        failed_group_pks = set() # Update below as we check each group.
+        groups_with_cloud_pks = set(self.groups.filter(active=True, inventory_sources__active=True, inventory_sources__source__in=CLOUD_INVENTORY_SOURCES).values_list('pk', flat=True))
+        groups_to_update = {}
+        # Build list of group pks to check, starting with the groups at the
+        # deepest level within the tree.
+        root_group_pks = set(self.root_groups.values_list('pk', flat=True))
+        group_depths = {} # pk: max_depth
+        def update_group_depths(group_pk, current_depth=0):
+            max_depth = group_depths.get(group_pk, 0)
+            if current_depth > max_depth:
+                group_depths[group_pk] = current_depth
+            for child_pk in group_children_map.get(group_pk, set()):
+                update_group_depths(child_pk, current_depth + 1)
+        for group_pk in root_group_pks:
+            update_group_depths(group_pk)
+        group_pks_to_check = [x[1] for x in sorted([(v,k) for k,v in group_depths.items()], reverse=True)]
+        for group_pk in group_pks_to_check:
+            # Get all children and host pks for this group.
+            parent_pks_to_check = set([group_pk])
+            parent_pks_checked = set()
+            child_pks = set()
+            host_pks = set()
+            while parent_pks_to_check:
+                for parent_pk in list(parent_pks_to_check):
+                    c_ids = group_children_map.get(parent_pk, set())
+                    child_pks.update(c_ids)
+                    parent_pks_to_check.remove(parent_pk)
+                    parent_pks_checked.add(parent_pk)
+                    parent_pks_to_check.update(c_ids - parent_pks_checked)
+                    h_ids = group_hosts_map.get(parent_pk, set())
+                    host_pks.update(h_ids)
+            # Define updates needed for this group.
+            group_updates = groups_to_update.setdefault(group_pk, {})
+            group_updates.update({
+                'total_hosts': len(active_host_pks & host_pks),
+                'has_active_failures': bool(failed_host_pks & host_pks),
+                'hosts_with_active_failures': len(failed_host_pks & host_pks),
+                'total_groups': len(child_pks),
+                'groups_with_active_failures': len(failed_group_pks & child_pks),
+                'has_inventory_sources': bool(group_pk in groups_with_cloud_pks),
+            })
+            if group_updates['has_active_failures']:
+                failed_group_pks.add(group_pk)
+        # Now apply updates to each group as needed.
+        for group in self.groups.filter(pk__in=groups_to_update.keys()):
+            group_updates = groups_to_update[group.pk]
+            for field, value in group_updates.items():
+                if getattr(group, field) != value:
+                    setattr(group, field, value)
+                else:
+                    group_updates.pop(field)
+            if group_updates:
+                group.save(update_fields=group_updates.keys())
+
     def update_computed_fields(self, update_groups=True, update_hosts=True):
         '''
         Update model fields that are computed from database relationships.
         '''
         logger.debug("Going to update inventory computed fields")
         if update_hosts:
-            for host in self.hosts.filter(active=True):
-                host.update_computed_fields(update_inventory=False,
-                                            update_groups=False)
+            self.update_host_computed_fields()
         if update_groups:
-            for group in self.groups.filter(active=True):
-                group.update_computed_fields()
+            self.update_group_computed_fields()
         active_hosts = self.hosts.filter(active=True)
         failed_hosts = active_hosts.filter(has_active_failures=True)
         active_groups = self.groups.filter(active=True)
         failed_groups = active_groups.filter(has_active_failures=True)
         active_inventory_sources = self.inventory_sources.filter(active=True, source__in=CLOUD_INVENTORY_SOURCES)
-        #failed_inventory_sources = active_inventory_sources.filter(last_update_failed=True)
         failed_inventory_sources = active_inventory_sources.filter(last_job_failed=True)
         computed_fields = {
             'has_active_failures': bool(failed_hosts.count()),
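All of the new traversal code in this file replaces per-object recursive queries with the same two-step shape: pull the complete edge list for the inventory in one query (the get_group_*_map methods above), then expand the transitive closure in memory with a visited set, which both avoids O(depth) queries and terminates even when group relations contain cycles. The loop recurs in update_group_computed_fields above and in get_all_parents / get_all_children / get_all_hosts below; a generic sketch of the shared logic:

    def transitive_closure(start_pks, edge_map):
        '''Return all pks reachable from start_pks, where edge_map maps
        pk -> set of adjacent pks (e.g. from get_group_children_map).
        The checked set makes cyclical relations safe.'''
        to_check = set(start_pks)
        checked = set()
        reached = set()
        while to_check:
            pk = to_check.pop()
            checked.add(pk)
            adjacent = edge_map.get(pk, set())
            reached.update(adjacent)
            to_check.update(adjacent - checked)
        return reached

The real methods fold extra bookkeeping into the loop (host collection, depth ordering), but the closure logic is the same.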
@@ -232,14 +383,15 @@ class Host(CommonModelNameNotUnique):
     def get_absolute_url(self):
         return reverse('api:host_detail', args=(self.pk,))

-    def mark_inactive(self, save=True):
+    def mark_inactive(self, save=True, from_inventory_import=False):
         '''
         When marking hosts inactive, remove all associations to related
         inventory sources.
         '''
         super(Host, self).mark_inactive(save=save)
-        self.inventory_sources.clear()
-        self.clear_cached_values()
+        if not from_inventory_import:
+            self.inventory_sources.clear()
+            self.clear_cached_values()

     def update_computed_fields(self, update_inventory=True, update_groups=True):
         '''
@@ -280,10 +432,19 @@ class Host(CommonModelNameNotUnique):
Return all groups of which this host is a member, avoiding infinite Return all groups of which this host is a member, avoiding infinite
recursion in the case of cyclical group relations. recursion in the case of cyclical group relations.
''' '''
qs = self.groups.distinct() group_parents_map = self.inventory.get_group_parents_map()
for group in self.groups.all(): group_pks = set(self.groups.values_list('pk', flat=True))
qs = qs | group.all_parents child_pks_to_check = set()
return qs child_pks_to_check.update(group_pks)
child_pks_checked = set()
while child_pks_to_check:
for child_pk in list(child_pks_to_check):
p_ids = group_parents_map.get(child_pk, set())
group_pks.update(p_ids)
child_pks_to_check.remove(child_pk)
child_pks_checked.add(child_pk)
child_pks_to_check.update(p_ids - child_pks_checked)
return Group.objects.filter(pk__in=group_pks).distinct()
def update_cached_values(self): def update_cached_values(self):
cacheable_data = {"%s_all_groups" % self.id: [{'id': g.id, 'name': g.name} for g in self.all_groups.all()], cacheable_data = {"%s_all_groups" % self.id: [{'id': g.id, 'name': g.name} for g in self.all_groups.all()],
@@ -422,7 +583,7 @@ class Group(CommonModelNameNotUnique):
             mark_actual()
         update_inventory_computed_fields.delay(self.id, True)

-    def mark_inactive(self, save=True, recompute=True):
+    def mark_inactive(self, save=True, recompute=True, from_inventory_import=False):
         '''
         When marking groups inactive, remove all associations to related
         groups/hosts/inventory_sources.
@@ -436,7 +597,9 @@ class Group(CommonModelNameNotUnique):
             self.hosts.clear()
         i = self.inventory

-        if recompute:
+        if from_inventory_import:
+            super(Group, self).mark_inactive(save=save)
+        elif recompute:
             with ignore_inventory_computed_fields():
                 mark_actual()
                 i.update_computed_fields()
@@ -475,16 +638,21 @@ class Group(CommonModelNameNotUnique):
     def get_all_parents(self, except_pks=None):
         '''
-        Return all parents of this group recursively, avoiding infinite
-        recursion in the case of cyclical relations. The group itself will be
-        excluded unless there is a cycle leading back to it.
+        Return all parents of this group recursively. The group itself will
+        be excluded unless there is a cycle leading back to it.
         '''
-        except_pks = except_pks or set()
-        except_pks.add(self.pk)
-        qs = self.parents.distinct()
-        for group in self.parents.exclude(pk__in=except_pks):
-            qs = qs | group.get_all_parents(except_pks)
-        return qs
+        group_parents_map = self.inventory.get_group_parents_map()
+        child_pks_to_check = set([self.pk])
+        child_pks_checked = set()
+        parent_pks = set()
+        while child_pks_to_check:
+            for child_pk in list(child_pks_to_check):
+                p_ids = group_parents_map.get(child_pk, set())
+                parent_pks.update(p_ids)
+                child_pks_to_check.remove(child_pk)
+                child_pks_checked.add(child_pk)
+                child_pks_to_check.update(p_ids - child_pks_checked)
+        return Group.objects.filter(pk__in=parent_pks).distinct()

     @property
     def all_parents(self):
@@ -492,16 +660,21 @@ class Group(CommonModelNameNotUnique):
     def get_all_children(self, except_pks=None):
         '''
-        Return all children of this group recursively, avoiding infinite
-        recursion in the case of cyclical relations. The group itself will be
-        excluded unless there is a cycle leading back to it.
+        Return all children of this group recursively. The group itself will
+        be excluded unless there is a cycle leading back to it.
         '''
-        except_pks = except_pks or set()
-        except_pks.add(self.pk)
-        qs = self.children.distinct()
-        for group in self.children.exclude(pk__in=except_pks):
-            qs = qs | group.get_all_children(except_pks)
-        return qs
+        group_children_map = self.inventory.get_group_children_map()
+        parent_pks_to_check = set([self.pk])
+        parent_pks_checked = set()
+        child_pks = set()
+        while parent_pks_to_check:
+            for parent_pk in list(parent_pks_to_check):
+                c_ids = group_children_map.get(parent_pk, set())
+                child_pks.update(c_ids)
+                parent_pks_to_check.remove(parent_pk)
+                parent_pks_checked.add(parent_pk)
+                parent_pks_to_check.update(c_ids - parent_pks_checked)
+        return Group.objects.filter(pk__in=child_pks).distinct()

     @property
     def all_children(self):
@@ -509,15 +682,22 @@ class Group(CommonModelNameNotUnique):
     def get_all_hosts(self, except_group_pks=None):
         '''
-        Return all hosts associated with this group or any of its children,
-        avoiding infinite recursion in the case of cyclical group relations.
+        Return all hosts associated with this group or any of its children.
         '''
-        except_group_pks = except_group_pks or set()
-        except_group_pks.add(self.pk)
-        qs = self.hosts.distinct()
-        for group in self.children.exclude(pk__in=except_group_pks):
-            qs = qs | group.get_all_hosts(except_group_pks)
-        return qs
+        group_children_map = self.inventory.get_group_children_map()
+        group_hosts_map = self.inventory.get_group_hosts_map()
+        parent_pks_to_check = set([self.pk])
+        parent_pks_checked = set()
+        host_pks = set()
+        while parent_pks_to_check:
+            for parent_pk in list(parent_pks_to_check):
+                c_ids = group_children_map.get(parent_pk, set())
+                parent_pks_to_check.remove(parent_pk)
+                parent_pks_checked.add(parent_pk)
+                parent_pks_to_check.update(c_ids - parent_pks_checked)
+                h_ids = group_hosts_map.get(parent_pk, set())
+                host_pks.update(h_ids)
+        return Host.objects.filter(pk__in=host_pks).distinct()

     @property
     def all_hosts(self):

View File

@@ -445,7 +445,7 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
             contact_name='AWX Admin',
             contact_email='awx@example.com',
             license_date=int(time.time() + 3600),
-            instance_count=500,
+            instance_count=10000,
         )
         handle, license_path = tempfile.mkstemp(suffix='.json')
         os.close(handle)
@@ -565,7 +565,7 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
         result, stdout, stderr = self.run_command('inventory_import',
                                                   inventory_id=new_inv.pk,
                                                   source=inv_src)
-        self.assertEqual(result, None)
+        self.assertEqual(result, None, stdout + stderr)
         # Check that inventory is populated as expected.
         new_inv = Inventory.objects.get(pk=new_inv.pk)
         expected_group_names = set(['servers', 'dbservers', 'webservers'])
@@ -637,7 +637,7 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
source=self.ini_path, source=self.ini_path,
overwrite=overwrite, overwrite=overwrite,
overwrite_vars=overwrite_vars) overwrite_vars=overwrite_vars)
self.assertEqual(result, None) self.assertEqual(result, None, stdout + stderr)
# Check that inventory is populated as expected. # Check that inventory is populated as expected.
new_inv = Inventory.objects.get(pk=new_inv.pk) new_inv = Inventory.objects.get(pk=new_inv.pk)
expected_group_names = set(['servers', 'dbservers', 'webservers', expected_group_names = set(['servers', 'dbservers', 'webservers',
@@ -828,7 +828,7 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
         result, stdout, stderr = self.run_command('inventory_import',
                                                   inventory_id=new_inv.pk,
                                                   source=source)
-        self.assertEqual(result, None)
+        self.assertEqual(result, None, stdout + stderr)
         # Check that inventory is populated as expected.
         new_inv = Inventory.objects.get(pk=new_inv.pk)
         self.assertEqual(old_inv.variables_dict, new_inv.variables_dict)
@@ -860,14 +860,13 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
         new_inv = self.organizations[0].inventories.create(name='newec2')
         self.assertEqual(new_inv.hosts.count(), 0)
         self.assertEqual(new_inv.groups.count(), 0)
-        #inv_file = os.path.join(os.path.dirname(__file__), 'data',
-        #                        'large_ec2_inventory.py')
         os.chdir(os.path.join(os.path.dirname(__file__), 'data'))
         inv_file = 'large_ec2_inventory.py'
-        settings.DEBUG = True
         result, stdout, stderr = self.run_command('inventory_import',
                                                   inventory_id=new_inv.pk,
                                                   source=inv_file)
-        self.assertEqual(result, None, stdout+stderr)
+        self.assertEqual(result, None, stdout + stderr)
         # Check that inventory is populated as expected within a reasonable
         # amount of time. Computed fields should also be updated.
         new_inv = Inventory.objects.get(pk=new_inv.pk)
@@ -875,5 +874,45 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
         self.assertNotEqual(new_inv.groups.count(), 0)
         self.assertNotEqual(new_inv.total_hosts, 0)
         self.assertNotEqual(new_inv.total_groups, 0)
-        self.assertElapsedLessThan(60)
+        self.assertElapsedLessThan(30)
+
+    def _get_ngroups_for_nhosts(self, n):
+        if n > 0:
+            return min(n, 10) + ((n - 1) / 10 + 1) + ((n - 1) / 100 + 1) + ((n - 1) / 1000 + 1)
+        else:
+            return 0
+
+    def _check_largeinv_import(self, new_inv, nhosts, nhosts_inactive=0):
+        self._start_time = time.time()
+        inv_file = os.path.join(os.path.dirname(__file__), 'data', 'largeinv.py')
+        ngroups = self._get_ngroups_for_nhosts(nhosts)
+        os.environ['NHOSTS'] = str(nhosts)
+        result, stdout, stderr = self.run_command('inventory_import',
+                                                  inventory_id=new_inv.pk,
+                                                  source=inv_file,
+                                                  overwrite=True, verbosity=0)
+        self.assertEqual(result, None, stdout + stderr)
+        # Check that inventory is populated as expected within a reasonable
+        # amount of time. Computed fields should also be updated.
+        new_inv = Inventory.objects.get(pk=new_inv.pk)
+        self.assertEqual(new_inv.hosts.filter(active=True).count(), nhosts)
+        self.assertEqual(new_inv.groups.filter(active=True).count(), ngroups)
+        self.assertEqual(new_inv.hosts.filter(active=False).count(), nhosts_inactive)
+        self.assertEqual(new_inv.total_hosts, nhosts)
+        self.assertEqual(new_inv.total_groups, ngroups)
+        self.assertElapsedLessThan(30)
+
+    def test_large_inventory_file(self):
+        new_inv = self.organizations[0].inventories.create(name='largeinv')
+        self.assertEqual(new_inv.hosts.count(), 0)
+        self.assertEqual(new_inv.groups.count(), 0)
+        settings.DEBUG = True
+        nhosts = 2000
+        # Test initial import into empty inventory.
+        self._check_largeinv_import(new_inv, nhosts, 0)
+        # Test re-importing and overwriting.
+        self._check_largeinv_import(new_inv, nhosts, 0)
+        # Test re-importing with only half as many hosts.
+        self._check_largeinv_import(new_inv, nhosts / 2, nhosts / 2)
+        # Test re-importing that clears all hosts.
+        self._check_largeinv_import(new_inv, 0, nhosts)
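For reference, _get_ngroups_for_nhosts mirrors how largeinv.py (below) assigns groups: up to ten modulus-based groups, plus one rollup group per block of 10, 100, and 1000 hosts, using Python 2 integer division. A spot check:

    # nhosts=2000:
    #   min(2000, 10)        -> 10   modulus groups (evens/odds, threes..tens)
    #   (1999 / 10) + 1      -> 200  group-by-10s rollups
    #   (1999 / 100) + 1     -> 20   group-by-100s rollups
    #   (1999 / 1000) + 1    -> 2    group-by-1000s rollups
    #                           232 groups total

so the test_large_inventory_file run with nhosts=2000 expects 232 active groups.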

awx/main/tests/data/largeinv.py (new executable file, 63 additions)
View File

@@ -0,0 +1,63 @@
+#!/usr/bin/env python
+
+# Python
+import json
+import optparse
+import os
+
+nhosts = int(os.environ.get('NHOSTS', 100))
+
+inv_list = {
+    '_meta': {
+        'hostvars': {},
+    },
+}
+
+for n in xrange(nhosts):
+    hostname = 'host-%08d.example.com' % n
+    group_evens_odds = 'evens.example.com' if n % 2 == 0 else 'odds.example.com'
+    group_threes = 'threes.example.com' if n % 3 == 0 else ''
+    group_fours = 'fours.example.com' if n % 4 == 0 else ''
+    group_fives = 'fives.example.com' if n % 5 == 0 else ''
+    group_sixes = 'sixes.example.com' if n % 6 == 0 else ''
+    group_sevens = 'sevens.example.com' if n % 7 == 0 else ''
+    group_eights = 'eights.example.com' if n % 8 == 0 else ''
+    group_nines = 'nines.example.com' if n % 9 == 0 else ''
+    group_tens = 'tens.example.com' if n % 10 == 0 else ''
+    group_by_10s = 'group-%07dX.example.com' % (n / 10)
+    group_by_100s = 'group-%06dXX.example.com' % (n / 100)
+    group_by_1000s = 'group-%05dXXX.example.com' % (n / 1000)
+    for group in [group_evens_odds, group_threes, group_fours, group_fives,
+                  group_sixes, group_sevens, group_eights, group_nines,
+                  group_tens, group_by_10s]:
+        if not group:
+            continue
+        if group in inv_list:
+            inv_list[group]['hosts'].append(hostname)
+        else:
+            inv_list[group] = {'hosts': [hostname], 'children': [], 'vars': {'group_prefix': group.split('.')[0]}}
+    if group_by_1000s not in inv_list:
+        inv_list[group_by_1000s] = {'hosts': [], 'children': [], 'vars': {'group_prefix': group_by_1000s.split('.')[0]}}
+    if group_by_100s not in inv_list:
+        inv_list[group_by_100s] = {'hosts': [], 'children': [], 'vars': {'group_prefix': group_by_100s.split('.')[0]}}
+    if group_by_100s not in inv_list[group_by_1000s]['children']:
+        inv_list[group_by_1000s]['children'].append(group_by_100s)
+    if group_by_10s not in inv_list[group_by_100s]['children']:
+        inv_list[group_by_100s]['children'].append(group_by_10s)
+    inv_list['_meta']['hostvars'][hostname] = {
+        'ansible_ssh_user': 'example',
+        'ansible_connection': 'local',
+        'host_prefix': hostname.split('.')[0],
+        'host_id': n,
+    }
+
+if __name__ == '__main__':
+    parser = optparse.OptionParser()
+    parser.add_option('--list', action='store_true', dest='list')
+    parser.add_option('--host', dest='hostname', default='')
+    options, args = parser.parse_args()
+    if options.list:
+        print json.dumps(inv_list, indent=4)
+    elif options.hostname:
+        print json.dumps(inv_list['_meta']['hostvars'][options.hostname], indent=4)
+    else:
+        print json.dumps({}, indent=4)
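The script follows Ansible's external inventory script convention (--list returns the whole inventory, --host returns variables for one host) and sizes itself from the NHOSTS environment variable, so it can also be exercised outside the test suite. A possible standalone check (the path assumes the repository root as working directory):

    # Example: generate a 500-host inventory and count hosts and groups.
    import json
    import os
    import subprocess

    os.environ['NHOSTS'] = '500'
    output = subprocess.check_output(
        ['python', 'awx/main/tests/data/largeinv.py', '--list'])
    data = json.loads(output)
    print len(data['_meta']['hostvars'])    # -> 500
    print len(data) - 1                     # group count (excluding _meta)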

View File

@@ -614,7 +614,7 @@ class InventoryTest(BaseTest):
         # data used for testing listing all hosts that are transitive members of a group
         g2 = Group.objects.get(name='web4')
-        nh = Host.objects.create(name='newhost.example.com', inventory=inva,
+        nh = Host.objects.create(name='newhost.example.com', inventory=g2.inventory,
                                  created_by=self.super_django_user)
         g2.hosts.add(nh)
         g2.save()