diff --git a/awx/main/exceptions.py b/awx/main/exceptions.py index 8aadfd80b0..64cbc94783 100644 --- a/awx/main/exceptions.py +++ b/awx/main/exceptions.py @@ -30,3 +30,10 @@ class _AwxTaskError(): AwxTaskError = _AwxTaskError() + + +class PostRunError(Exception): + def __init__(self, msg, status='failed', tb=''): + self.status = status + self.tb = tb + super(PostRunError, self).__init__(msg) diff --git a/awx/main/management/commands/inventory_import.py b/awx/main/management/commands/inventory_import.py index 2636cf8d10..2179faad6b 100644 --- a/awx/main/management/commands/inventory_import.py +++ b/awx/main/management/commands/inventory_import.py @@ -19,6 +19,9 @@ from django.core.management.base import BaseCommand, CommandError from django.db import connection, transaction from django.utils.encoding import smart_text +# DRF error class to distinguish license exceptions +from rest_framework.exceptions import PermissionDenied + # AWX inventory imports from awx.main.models.inventory import ( Inventory, @@ -839,9 +842,9 @@ class Command(BaseCommand): source_vars = self.all_group.variables remote_license_type = source_vars.get('tower_metadata', {}).get('license_type', None) if remote_license_type is None: - raise CommandError('Unexpected Error: Tower inventory plugin missing needed metadata!') + raise PermissionDenied('Unexpected Error: Tower inventory plugin missing needed metadata!') if local_license_type != remote_license_type: - raise CommandError('Tower server licenses must match: source: {} local: {}'.format( + raise PermissionDenied('Tower server licenses must match: source: {} local: {}'.format( remote_license_type, local_license_type )) @@ -850,7 +853,7 @@ class Command(BaseCommand): local_license_type = license_info.get('license_type', 'UNLICENSED') if local_license_type == 'UNLICENSED': logger.error(LICENSE_NON_EXISTANT_MESSAGE) - raise CommandError('No license found!') + raise PermissionDenied('No license found!') elif local_license_type == 'open': return 
available_instances = license_info.get('available_instances', 0) @@ -861,7 +864,7 @@ class Command(BaseCommand): if time_remaining <= 0: if hard_error: logger.error(LICENSE_EXPIRED_MESSAGE) - raise CommandError("License has expired!") + raise PermissionDenied("License has expired!") else: logger.warning(LICENSE_EXPIRED_MESSAGE) # special check for tower-type inventory sources @@ -878,7 +881,7 @@ class Command(BaseCommand): } if hard_error: logger.error(LICENSE_MESSAGE % d) - raise CommandError('License count exceeded!') + raise PermissionDenied('License count exceeded!') else: logger.warning(LICENSE_MESSAGE % d) @@ -893,7 +896,7 @@ class Command(BaseCommand): active_count = Host.objects.org_active_count(org.id) if active_count > org.max_hosts: - raise CommandError('Host limit for organization exceeded!') + raise PermissionDenied('Host limit for organization exceeded!') def mark_license_failure(self, save=True): self.inventory_update.license_error = True @@ -958,7 +961,17 @@ class Command(BaseCommand): ).load() logger.debug('Finished loading from source: %s', source) - status, tb, exc = self.perform_update(options, data, inventory_update) + + status, tb, exc = 'error', '', None + try: + self.perform_update(options, data, inventory_update) + status = 'successful' + except Exception as e: + exc = e + if isinstance(e, KeyboardInterrupt): + status = 'canceled' + else: + tb = traceback.format_exc() with ignore_inventory_computed_fields(): inventory_update = InventoryUpdate.objects.get(pk=inventory_update.pk) @@ -1017,119 +1030,106 @@ class Command(BaseCommand): try: self.check_license() - except CommandError as e: + except PermissionDenied as e: self.mark_license_failure(save=True) raise e try: # Check the per-org host limits self.check_org_host_limit() - except CommandError as e: + except PermissionDenied as e: self.mark_org_limits_failure(save=True) raise e - status, tb, exc = 'error', '', None - try: - if settings.SQL_DEBUG: - queries_before = len(connection.queries) 
+ if settings.SQL_DEBUG: + queries_before = len(connection.queries) - # Update inventory update for this command line invocation. - with ignore_inventory_computed_fields(): - iu = self.inventory_update - if iu.status != 'running': - with transaction.atomic(): - self.inventory_update.status = 'running' - self.inventory_update.save() + # Update inventory update for this command line invocation. + with ignore_inventory_computed_fields(): + # TODO: move this to before perform_update + iu = self.inventory_update + if iu.status != 'running': + with transaction.atomic(): + self.inventory_update.status = 'running' + self.inventory_update.save() - logger.info('Processing JSON output...') - inventory = MemInventory( - group_filter_re=self.group_filter_re, host_filter_re=self.host_filter_re) - inventory = dict_to_mem_data(data, inventory=inventory) + logger.info('Processing JSON output...') + inventory = MemInventory( + group_filter_re=self.group_filter_re, host_filter_re=self.host_filter_re) + inventory = dict_to_mem_data(data, inventory=inventory) - logger.info('Loaded %d groups, %d hosts', len(inventory.all_group.all_groups), - len(inventory.all_group.all_hosts)) + logger.info('Loaded %d groups, %d hosts', len(inventory.all_group.all_groups), + len(inventory.all_group.all_hosts)) - if self.exclude_empty_groups: - inventory.delete_empty_groups() + if self.exclude_empty_groups: + inventory.delete_empty_groups() - self.all_group = inventory.all_group + self.all_group = inventory.all_group - if settings.DEBUG: - # depending on inventory source, this output can be - # *exceedingly* verbose - crawling a deeply nested - # inventory/group data structure and printing metadata about - # each host and its memberships - # - # it's easy for this scale of data to overwhelm pexpect, - # (and it's likely only useful for purposes of debugging the - # actual inventory import code), so only print it if we have to: - # 
https://github.com/ansible/ansible-tower/issues/7414#issuecomment-321615104 - self.all_group.debug_tree() + if settings.DEBUG: + # depending on inventory source, this output can be + # *exceedingly* verbose - crawling a deeply nested + # inventory/group data structure and printing metadata about + # each host and its memberships + # + # it's easy for this scale of data to overwhelm pexpect, + # (and it's likely only useful for purposes of debugging the + # actual inventory import code), so only print it if we have to: + # https://github.com/ansible/ansible-tower/issues/7414#issuecomment-321615104 + self.all_group.debug_tree() - with batch_role_ancestor_rebuilding(): - # If using with transaction.atomic() with try ... catch, - # with transaction.atomic() must be inside the try section of the code as per Django docs - try: - # Ensure that this is managed as an atomic SQL transaction, - # and thus properly rolled back if there is an issue. - with transaction.atomic(): - # Merge/overwrite inventory into database. - if settings.SQL_DEBUG: - logger.warning('loading into database...') - with ignore_inventory_computed_fields(): - if getattr(settings, 'ACTIVITY_STREAM_ENABLED_FOR_INVENTORY_SYNC', True): + with batch_role_ancestor_rebuilding(): + # If using with transaction.atomic() with try ... catch, + # with transaction.atomic() must be inside the try section of the code as per Django docs + try: + # Ensure that this is managed as an atomic SQL transaction, + # and thus properly rolled back if there is an issue. + with transaction.atomic(): + # Merge/overwrite inventory into database. 
+ if settings.SQL_DEBUG: + logger.warning('loading into database...') + with ignore_inventory_computed_fields(): + if getattr(settings, 'ACTIVITY_STREAM_ENABLED_FOR_INVENTORY_SYNC', True): + self.load_into_database() + else: + with disable_activity_stream(): self.load_into_database() - else: - with disable_activity_stream(): - self.load_into_database() - if settings.SQL_DEBUG: - queries_before2 = len(connection.queries) - self.inventory.update_computed_fields() if settings.SQL_DEBUG: - logger.warning('update computed fields took %d queries', - len(connection.queries) - queries_before2) + queries_before2 = len(connection.queries) + self.inventory.update_computed_fields() + if settings.SQL_DEBUG: + logger.warning('update computed fields took %d queries', + len(connection.queries) - queries_before2) - # Check if the license is valid. - # If the license is not valid, a CommandError will be thrown, - # and inventory update will be marked as invalid. - # with transaction.atomic() will roll back the changes. - license_fail = True - self.check_license() + # Check if the license is valid. + # If the license is not valid, a PermissionDenied will be thrown, + # and inventory update will be marked as invalid. + # with transaction.atomic() will roll back the changes. 
+ license_fail = True + self.check_license() - # Check the per-org host limits - license_fail = False - self.check_org_host_limit() - except CommandError as e: - if license_fail: - self.mark_license_failure(save=True) - else: - self.mark_org_limits_failure(save=True) - raise e - - if settings.SQL_DEBUG: - logger.warning('Inventory import completed for %s in %0.1fs', - self.inventory_source.name, time.time() - begin) + # Check the per-org host limits + license_fail = False + self.check_org_host_limit() + except PermissionDenied as e: + if license_fail: + self.mark_license_failure(save=True) else: - logger.info('Inventory import completed for %s in %0.1fs', - self.inventory_source.name, time.time() - begin) - status = 'successful' + self.mark_org_limits_failure(save=True) + raise e - # If we're in debug mode, then log the queries and time - # used to do the operation. if settings.SQL_DEBUG: - queries_this_import = connection.queries[queries_before:] - sqltime = sum(float(x['time']) for x in queries_this_import) - logger.warning('Inventory import required %d queries ' - 'taking %0.3fs', len(queries_this_import), - sqltime) - except Exception as e: - if isinstance(e, KeyboardInterrupt): - status = 'canceled' - exc = e - elif isinstance(e, CommandError): - exc = e + logger.warning('Inventory import completed for %s in %0.1fs', + self.inventory_source.name, time.time() - begin) else: - tb = traceback.format_exc() - exc = e + logger.info('Inventory import completed for %s in %0.1fs', + self.inventory_source.name, time.time() - begin) - return status, tb, exc + # If we're in debug mode, then log the queries and time + # used to do the operation. 
+ if settings.SQL_DEBUG: + queries_this_import = connection.queries[queries_before:] + sqltime = sum(float(x['time']) for x in queries_this_import) + logger.warning('Inventory import required %d queries ' + 'taking %0.3fs', len(queries_this_import), + sqltime) diff --git a/awx/main/tasks.py b/awx/main/tasks.py index d8a22440f8..238f002562 100644 --- a/awx/main/tasks.py +++ b/awx/main/tasks.py @@ -63,7 +63,7 @@ from awx.main.models import ( build_safe_env, enforce_bigint_pk_migration ) from awx.main.constants import ACTIVE_STATES -from awx.main.exceptions import AwxTaskError +from awx.main.exceptions import AwxTaskError, PostRunError from awx.main.queue import CallbackQueueDispatcher from awx.main.isolated import manager as isolated_manager from awx.main.dispatch.publish import task @@ -1229,14 +1229,7 @@ class BaseTask(object): # ansible-inventory and the awx.main.commands.inventory_import # logger if isinstance(self, RunInventoryUpdate): - if not getattr(self, 'end_line', None): - # this is the very first event - # note the end_line - self.end_line = event_data['end_line'] - else: - num_lines = event_data['end_line'] - event_data['start_line'] - event_data['start_line'] = self.end_line + 1 - self.end_line = event_data['end_line'] = event_data['start_line'] + num_lines + self.end_line = event_data['end_line'] if event_data.get(self.event_data_key, None): if self.event_data_key != 'job_id': @@ -1534,6 +1527,12 @@ class BaseTask(object): try: self.post_run_hook(self.instance, status) + except PostRunError as exc: + if status == 'successful': + status = exc.status + extra_update_fields['job_explanation'] = exc.args[0] + if exc.tb: + extra_update_fields['result_traceback'] = exc.tb except Exception: logger.exception('{} Post run hook errored.'.format(self.instance.log_format)) @@ -2744,27 +2743,28 @@ class RunInventoryUpdate(BaseTask): # Mock ansible-runner events class CallbackHandler(logging.Handler): def __init__(self, event_handler, cancel_callback, job_timeout, 
verbosity, - counter=0, **kwargs): + start_time=None, counter=0, initial_line=0, **kwargs): self.event_handler = event_handler self.cancel_callback = cancel_callback self.job_timeout = job_timeout - self.job_start = time.time() + if start_time is None: + self.job_start = now() + else: + self.job_start = start_time self.last_check = self.job_start - # TODO: we do not have events from the ansible-inventory process - # so there is no way to know initial counter of start line self.counter = counter self.skip_level = [logging.WARNING, logging.INFO, logging.DEBUG, 0][verbosity] - self._start_line = 0 + self._start_line = initial_line super(CallbackHandler, self).__init__(**kwargs) def emit(self, record): - this_time = time.time() - if this_time - self.last_check > 0.5: + this_time = now() + if (this_time - self.last_check).total_seconds() > 0.5: self.last_check = this_time if self.cancel_callback(): - raise RuntimeError('Inventory update has been canceled') - if self.job_timeout and ((this_time - self.job_start) > self.job_timeout): - raise RuntimeError('Inventory update has timed out') + raise PostRunError('Inventory update has been canceled', status='canceled') + if self.job_timeout and ((this_time - self.job_start).total_seconds() > self.job_timeout): + raise PostRunError('Inventory update has timed out', status='canceled') # skip logging for low severity logs if record.levelno < self.skip_level: @@ -2772,16 +2772,16 @@ class RunInventoryUpdate(BaseTask): self.counter += 1 msg = self.format(record) - n_lines = msg.strip().count('\n') # don't count new-lines at boundry of text + n_lines = len(msg.strip().split('\n')) # don't count new-lines at boundary of text dispatch_data = dict( created=now().isoformat(), event='verbose', counter=self.counter, - stdout=msg + '\n', + stdout=msg, start_line=self._start_line, end_line=self._start_line + n_lines ) - self._start_line += n_lines + 1 + self._start_line += n_lines self.event_handler(dispatch_data) @@ -2789,7 +2789,8 @@ class 
RunInventoryUpdate(BaseTask): self.event_handler, self.cancel_callback, verbosity=inventory_update.verbosity, job_timeout=self.get_instance_timeout(self.instance), - counter=self.event_ct + start_time=inventory_update.started, + counter=self.event_ct, initial_line=self.end_line ) inv_logger = logging.getLogger('awx.main.commands.inventory_import') handler.formatter = inv_logger.handlers[0].formatter @@ -2797,36 +2798,19 @@ class RunInventoryUpdate(BaseTask): from awx.main.management.commands.inventory_import import Command as InventoryImportCommand cmd = InventoryImportCommand() - exc = None try: - # note that we are only using the management command to - # save the inventory data to the database. - # we are not asking it to actually fetch hosts / groups. - # that work was taken care of earlier, when - # BaseTask.run called ansible-inventory (by way of ansible-runner) - # for us. - save_status, tb, exc = cmd.perform_update(options, data, inventory_update) - except Exception as raw_exc: - if exc is None: - exc = raw_exc - # Ignore license errors specifically - if 'Host limit for organization' not in str(exc) and 'License' not in str(exc): - raise raw_exc - - model_updates = {} - if save_status != status: - model_updates['status'] = save_status - if tb: - model_updates['result_traceback'] = tb - - if model_updates: - logger.info('{} had problems saving to database with {}'.format( - inventory_update.log_format, ', '.join(list(model_updates.keys())) - )) - model_updates['job_explanation'] = 'Update failed to save all changes to database properly.' - if exc: - model_updates['job_explanation'] += ' {}'.format(exc) - self.update_model(inventory_update.pk, **model_updates) + # save the inventory data to database. 
+ # canceling exceptions will be handled in the global post_run_hook + cmd.perform_update(options, data, inventory_update) + except PermissionDenied as exc: + logger.exception('License error saving {} content'.format(inventory_update.log_format)) + raise PostRunError(str(exc), status='error') + except Exception: + logger.exception('Exception saving {} content, rolling back changes.'.format( + inventory_update.log_format)) + raise PostRunError( + 'Error occured while saving inventory data, see traceback or server logs', + status='error', tb=traceback.format_exc()) @task(queue=get_local_queuename) diff --git a/awx/main/tests/functional/commands/test_inventory_import.py b/awx/main/tests/functional/commands/test_inventory_import.py index 3fe4b92b5a..0500ef197c 100644 --- a/awx/main/tests/functional/commands/test_inventory_import.py +++ b/awx/main/tests/functional/commands/test_inventory_import.py @@ -9,6 +9,9 @@ import os # Django from django.core.management.base import CommandError +# for license errors +from rest_framework.exceptions import PermissionDenied + # AWX from awx.main.management.commands import inventory_import from awx.main.models import Inventory, Host, Group, InventorySource @@ -322,6 +325,6 @@ def test_tower_version_compare(): "version": "2.0.1-1068-g09684e2c41" } } - with pytest.raises(CommandError): + with pytest.raises(PermissionDenied): cmd.remote_tower_license_compare('very_supported') cmd.remote_tower_license_compare('open')