Some fixes for line numbering, and fixes for license error handling (#8)

* Change handling of error cases to global post_run_hook
* Handle license errors correctly again
* Fix some issues with line ordering from the custom logging handler
* Remove debug log statement
* Use PermissionDenied for license errors
* More elegant handling of line initialization

Update tests to new exception type

Catch all save errors, fix timing offset bug

Fix license error handling inside import command
This commit is contained in:
Alan Rominger
2020-11-18 20:55:26 -05:00
parent ec93af4ba8
commit d6e84b54c9
4 changed files with 145 additions and 151 deletions

View File

@@ -30,3 +30,10 @@ class _AwxTaskError():
AwxTaskError = _AwxTaskError() AwxTaskError = _AwxTaskError()
class PostRunError(Exception):
    """Error raised from a post-run hook that carries a final job outcome.

    Besides the human-readable message, the exception records the status the
    job should end up with and an optional traceback string, so the code that
    catches it can copy those values onto the job record instead of merely
    logging the failure.

    Args:
        msg: Explanation of what went wrong; available as ``args[0]``.
        status: Status the job should be given. Defaults to ``'failed'``.
        tb: Formatted traceback text, or ``''`` when there is none to report.
    """

    def __init__(self, msg, status='failed', tb=''):
        self.status = status
        self.tb = tb
        # Pass only the message up so that ``args[0]`` / ``str(exc)`` stay the
        # plain explanation, without status or traceback mixed in.
        super(PostRunError, self).__init__(msg)

View File

@@ -19,6 +19,9 @@ from django.core.management.base import BaseCommand, CommandError
from django.db import connection, transaction from django.db import connection, transaction
from django.utils.encoding import smart_text from django.utils.encoding import smart_text
# DRF error class to distinguish license exceptions
from rest_framework.exceptions import PermissionDenied
# AWX inventory imports # AWX inventory imports
from awx.main.models.inventory import ( from awx.main.models.inventory import (
Inventory, Inventory,
@@ -839,9 +842,9 @@ class Command(BaseCommand):
source_vars = self.all_group.variables source_vars = self.all_group.variables
remote_license_type = source_vars.get('tower_metadata', {}).get('license_type', None) remote_license_type = source_vars.get('tower_metadata', {}).get('license_type', None)
if remote_license_type is None: if remote_license_type is None:
raise CommandError('Unexpected Error: Tower inventory plugin missing needed metadata!') raise PermissionDenied('Unexpected Error: Tower inventory plugin missing needed metadata!')
if local_license_type != remote_license_type: if local_license_type != remote_license_type:
raise CommandError('Tower server licenses must match: source: {} local: {}'.format( raise PermissionDenied('Tower server licenses must match: source: {} local: {}'.format(
remote_license_type, local_license_type remote_license_type, local_license_type
)) ))
@@ -850,7 +853,7 @@ class Command(BaseCommand):
local_license_type = license_info.get('license_type', 'UNLICENSED') local_license_type = license_info.get('license_type', 'UNLICENSED')
if local_license_type == 'UNLICENSED': if local_license_type == 'UNLICENSED':
logger.error(LICENSE_NON_EXISTANT_MESSAGE) logger.error(LICENSE_NON_EXISTANT_MESSAGE)
raise CommandError('No license found!') raise PermissionDenied('No license found!')
elif local_license_type == 'open': elif local_license_type == 'open':
return return
available_instances = license_info.get('available_instances', 0) available_instances = license_info.get('available_instances', 0)
@@ -861,7 +864,7 @@ class Command(BaseCommand):
if time_remaining <= 0: if time_remaining <= 0:
if hard_error: if hard_error:
logger.error(LICENSE_EXPIRED_MESSAGE) logger.error(LICENSE_EXPIRED_MESSAGE)
raise CommandError("License has expired!") raise PermissionDenied("License has expired!")
else: else:
logger.warning(LICENSE_EXPIRED_MESSAGE) logger.warning(LICENSE_EXPIRED_MESSAGE)
# special check for tower-type inventory sources # special check for tower-type inventory sources
@@ -878,7 +881,7 @@ class Command(BaseCommand):
} }
if hard_error: if hard_error:
logger.error(LICENSE_MESSAGE % d) logger.error(LICENSE_MESSAGE % d)
raise CommandError('License count exceeded!') raise PermissionDenied('License count exceeded!')
else: else:
logger.warning(LICENSE_MESSAGE % d) logger.warning(LICENSE_MESSAGE % d)
@@ -893,7 +896,7 @@ class Command(BaseCommand):
active_count = Host.objects.org_active_count(org.id) active_count = Host.objects.org_active_count(org.id)
if active_count > org.max_hosts: if active_count > org.max_hosts:
raise CommandError('Host limit for organization exceeded!') raise PermissionDenied('Host limit for organization exceeded!')
def mark_license_failure(self, save=True): def mark_license_failure(self, save=True):
self.inventory_update.license_error = True self.inventory_update.license_error = True
@@ -958,7 +961,17 @@ class Command(BaseCommand):
).load() ).load()
logger.debug('Finished loading from source: %s', source) logger.debug('Finished loading from source: %s', source)
status, tb, exc = self.perform_update(options, data, inventory_update)
status, tb, exc = 'error', '', None
try:
self.perform_update(options, data, inventory_update)
status = 'successful'
except Exception as e:
exc = e
if isinstance(e, KeyboardInterrupt):
status = 'canceled'
else:
tb = traceback.format_exc()
with ignore_inventory_computed_fields(): with ignore_inventory_computed_fields():
inventory_update = InventoryUpdate.objects.get(pk=inventory_update.pk) inventory_update = InventoryUpdate.objects.get(pk=inventory_update.pk)
@@ -1017,119 +1030,106 @@ class Command(BaseCommand):
try: try:
self.check_license() self.check_license()
except CommandError as e: except PermissionDenied as e:
self.mark_license_failure(save=True) self.mark_license_failure(save=True)
raise e raise e
try: try:
# Check the per-org host limits # Check the per-org host limits
self.check_org_host_limit() self.check_org_host_limit()
except CommandError as e: except PermissionDenied as e:
self.mark_org_limits_failure(save=True) self.mark_org_limits_failure(save=True)
raise e raise e
status, tb, exc = 'error', '', None if settings.SQL_DEBUG:
try: queries_before = len(connection.queries)
if settings.SQL_DEBUG:
queries_before = len(connection.queries)
# Update inventory update for this command line invocation. # Update inventory update for this command line invocation.
with ignore_inventory_computed_fields(): with ignore_inventory_computed_fields():
iu = self.inventory_update # TODO: move this to before perform_update
if iu.status != 'running': iu = self.inventory_update
with transaction.atomic(): if iu.status != 'running':
self.inventory_update.status = 'running' with transaction.atomic():
self.inventory_update.save() self.inventory_update.status = 'running'
self.inventory_update.save()
logger.info('Processing JSON output...') logger.info('Processing JSON output...')
inventory = MemInventory( inventory = MemInventory(
group_filter_re=self.group_filter_re, host_filter_re=self.host_filter_re) group_filter_re=self.group_filter_re, host_filter_re=self.host_filter_re)
inventory = dict_to_mem_data(data, inventory=inventory) inventory = dict_to_mem_data(data, inventory=inventory)
logger.info('Loaded %d groups, %d hosts', len(inventory.all_group.all_groups), logger.info('Loaded %d groups, %d hosts', len(inventory.all_group.all_groups),
len(inventory.all_group.all_hosts)) len(inventory.all_group.all_hosts))
if self.exclude_empty_groups: if self.exclude_empty_groups:
inventory.delete_empty_groups() inventory.delete_empty_groups()
self.all_group = inventory.all_group self.all_group = inventory.all_group
if settings.DEBUG: if settings.DEBUG:
# depending on inventory source, this output can be # depending on inventory source, this output can be
# *exceedingly* verbose - crawling a deeply nested # *exceedingly* verbose - crawling a deeply nested
# inventory/group data structure and printing metadata about # inventory/group data structure and printing metadata about
# each host and its memberships # each host and its memberships
# #
# it's easy for this scale of data to overwhelm pexpect, # it's easy for this scale of data to overwhelm pexpect,
# (and it's likely only useful for purposes of debugging the # (and it's likely only useful for purposes of debugging the
# actual inventory import code), so only print it if we have to: # actual inventory import code), so only print it if we have to:
# https://github.com/ansible/ansible-tower/issues/7414#issuecomment-321615104 # https://github.com/ansible/ansible-tower/issues/7414#issuecomment-321615104
self.all_group.debug_tree() self.all_group.debug_tree()
with batch_role_ancestor_rebuilding(): with batch_role_ancestor_rebuilding():
# If using with transaction.atomic() with try ... catch, # If using with transaction.atomic() with try ... catch,
# with transaction.atomic() must be inside the try section of the code as per Django docs # with transaction.atomic() must be inside the try section of the code as per Django docs
try: try:
# Ensure that this is managed as an atomic SQL transaction, # Ensure that this is managed as an atomic SQL transaction,
# and thus properly rolled back if there is an issue. # and thus properly rolled back if there is an issue.
with transaction.atomic(): with transaction.atomic():
# Merge/overwrite inventory into database. # Merge/overwrite inventory into database.
if settings.SQL_DEBUG: if settings.SQL_DEBUG:
logger.warning('loading into database...') logger.warning('loading into database...')
with ignore_inventory_computed_fields(): with ignore_inventory_computed_fields():
if getattr(settings, 'ACTIVITY_STREAM_ENABLED_FOR_INVENTORY_SYNC', True): if getattr(settings, 'ACTIVITY_STREAM_ENABLED_FOR_INVENTORY_SYNC', True):
self.load_into_database()
else:
with disable_activity_stream():
self.load_into_database() self.load_into_database()
else:
with disable_activity_stream():
self.load_into_database()
if settings.SQL_DEBUG:
queries_before2 = len(connection.queries)
self.inventory.update_computed_fields()
if settings.SQL_DEBUG: if settings.SQL_DEBUG:
logger.warning('update computed fields took %d queries', queries_before2 = len(connection.queries)
len(connection.queries) - queries_before2) self.inventory.update_computed_fields()
if settings.SQL_DEBUG:
logger.warning('update computed fields took %d queries',
len(connection.queries) - queries_before2)
# Check if the license is valid. # Check if the license is valid.
# If the license is not valid, a CommandError will be thrown, # If the license is not valid, a CommandError will be thrown,
# and inventory update will be marked as invalid. # and inventory update will be marked as invalid.
# with transaction.atomic() will roll back the changes. # with transaction.atomic() will roll back the changes.
license_fail = True license_fail = True
self.check_license() self.check_license()
# Check the per-org host limits # Check the per-org host limits
license_fail = False license_fail = False
self.check_org_host_limit() self.check_org_host_limit()
except CommandError as e: except PermissionDenied as e:
if license_fail: if license_fail:
self.mark_license_failure(save=True) self.mark_license_failure(save=True)
else:
self.mark_org_limits_failure(save=True)
raise e
if settings.SQL_DEBUG:
logger.warning('Inventory import completed for %s in %0.1fs',
self.inventory_source.name, time.time() - begin)
else: else:
logger.info('Inventory import completed for %s in %0.1fs', self.mark_org_limits_failure(save=True)
self.inventory_source.name, time.time() - begin) raise e
status = 'successful'
# If we're in debug mode, then log the queries and time
# used to do the operation.
if settings.SQL_DEBUG: if settings.SQL_DEBUG:
queries_this_import = connection.queries[queries_before:] logger.warning('Inventory import completed for %s in %0.1fs',
sqltime = sum(float(x['time']) for x in queries_this_import) self.inventory_source.name, time.time() - begin)
logger.warning('Inventory import required %d queries '
'taking %0.3fs', len(queries_this_import),
sqltime)
except Exception as e:
if isinstance(e, KeyboardInterrupt):
status = 'canceled'
exc = e
elif isinstance(e, CommandError):
exc = e
else: else:
tb = traceback.format_exc() logger.info('Inventory import completed for %s in %0.1fs',
exc = e self.inventory_source.name, time.time() - begin)
return status, tb, exc # If we're in debug mode, then log the queries and time
# used to do the operation.
if settings.SQL_DEBUG:
queries_this_import = connection.queries[queries_before:]
sqltime = sum(float(x['time']) for x in queries_this_import)
logger.warning('Inventory import required %d queries '
'taking %0.3fs', len(queries_this_import),
sqltime)

View File

@@ -63,7 +63,7 @@ from awx.main.models import (
build_safe_env, enforce_bigint_pk_migration build_safe_env, enforce_bigint_pk_migration
) )
from awx.main.constants import ACTIVE_STATES from awx.main.constants import ACTIVE_STATES
from awx.main.exceptions import AwxTaskError from awx.main.exceptions import AwxTaskError, PostRunError
from awx.main.queue import CallbackQueueDispatcher from awx.main.queue import CallbackQueueDispatcher
from awx.main.isolated import manager as isolated_manager from awx.main.isolated import manager as isolated_manager
from awx.main.dispatch.publish import task from awx.main.dispatch.publish import task
@@ -1229,14 +1229,7 @@ class BaseTask(object):
# ansible-inventory and the awx.main.commands.inventory_import # ansible-inventory and the awx.main.commands.inventory_import
# logger # logger
if isinstance(self, RunInventoryUpdate): if isinstance(self, RunInventoryUpdate):
if not getattr(self, 'end_line', None): self.end_line = event_data['end_line']
# this is the very first event
# note the end_line
self.end_line = event_data['end_line']
else:
num_lines = event_data['end_line'] - event_data['start_line']
event_data['start_line'] = self.end_line + 1
self.end_line = event_data['end_line'] = event_data['start_line'] + num_lines
if event_data.get(self.event_data_key, None): if event_data.get(self.event_data_key, None):
if self.event_data_key != 'job_id': if self.event_data_key != 'job_id':
@@ -1534,6 +1527,12 @@ class BaseTask(object):
try: try:
self.post_run_hook(self.instance, status) self.post_run_hook(self.instance, status)
except PostRunError as exc:
if status == 'successful':
status = exc.status
extra_update_fields['job_explanation'] = exc.args[0]
if exc.tb:
extra_update_fields['result_traceback'] = exc.tb
except Exception: except Exception:
logger.exception('{} Post run hook errored.'.format(self.instance.log_format)) logger.exception('{} Post run hook errored.'.format(self.instance.log_format))
@@ -2744,27 +2743,28 @@ class RunInventoryUpdate(BaseTask):
# Mock ansible-runner events # Mock ansible-runner events
class CallbackHandler(logging.Handler): class CallbackHandler(logging.Handler):
def __init__(self, event_handler, cancel_callback, job_timeout, verbosity, def __init__(self, event_handler, cancel_callback, job_timeout, verbosity,
counter=0, **kwargs): start_time=None, counter=0, initial_line=0, **kwargs):
self.event_handler = event_handler self.event_handler = event_handler
self.cancel_callback = cancel_callback self.cancel_callback = cancel_callback
self.job_timeout = job_timeout self.job_timeout = job_timeout
self.job_start = time.time() if start_time is None:
self.job_start = now()
else:
self.job_start = start_time
self.last_check = self.job_start self.last_check = self.job_start
# TODO: we do not have events from the ansible-inventory process
# so there is no way to know initial counter of start line
self.counter = counter self.counter = counter
self.skip_level = [logging.WARNING, logging.INFO, logging.DEBUG, 0][verbosity] self.skip_level = [logging.WARNING, logging.INFO, logging.DEBUG, 0][verbosity]
self._start_line = 0 self._start_line = initial_line
super(CallbackHandler, self).__init__(**kwargs) super(CallbackHandler, self).__init__(**kwargs)
def emit(self, record): def emit(self, record):
this_time = time.time() this_time = now()
if this_time - self.last_check > 0.5: if (this_time - self.last_check).total_seconds() > 0.5:
self.last_check = this_time self.last_check = this_time
if self.cancel_callback(): if self.cancel_callback():
raise RuntimeError('Inventory update has been canceled') raise PostRunError('Inventory update has been canceled', status='canceled')
if self.job_timeout and ((this_time - self.job_start) > self.job_timeout): if self.job_timeout and ((this_time - self.job_start).total_seconds() > self.job_timeout):
raise RuntimeError('Inventory update has timed out') raise PostRunError('Inventory update has timed out', status='canceled')
# skip logging for low severity logs # skip logging for low severity logs
if record.levelno < self.skip_level: if record.levelno < self.skip_level:
@@ -2772,16 +2772,16 @@ class RunInventoryUpdate(BaseTask):
self.counter += 1 self.counter += 1
msg = self.format(record) msg = self.format(record)
n_lines = msg.strip().count('\n') # don't count new-lines at boundry of text n_lines = len(msg.strip().split('\n')) # don't count new-lines at boundry of text
dispatch_data = dict( dispatch_data = dict(
created=now().isoformat(), created=now().isoformat(),
event='verbose', event='verbose',
counter=self.counter, counter=self.counter,
stdout=msg + '\n', stdout=msg,
start_line=self._start_line, start_line=self._start_line,
end_line=self._start_line + n_lines end_line=self._start_line + n_lines
) )
self._start_line += n_lines + 1 self._start_line += n_lines
self.event_handler(dispatch_data) self.event_handler(dispatch_data)
@@ -2789,7 +2789,8 @@ class RunInventoryUpdate(BaseTask):
self.event_handler, self.cancel_callback, self.event_handler, self.cancel_callback,
verbosity=inventory_update.verbosity, verbosity=inventory_update.verbosity,
job_timeout=self.get_instance_timeout(self.instance), job_timeout=self.get_instance_timeout(self.instance),
counter=self.event_ct start_time=inventory_update.started,
counter=self.event_ct, initial_line=self.end_line
) )
inv_logger = logging.getLogger('awx.main.commands.inventory_import') inv_logger = logging.getLogger('awx.main.commands.inventory_import')
handler.formatter = inv_logger.handlers[0].formatter handler.formatter = inv_logger.handlers[0].formatter
@@ -2797,36 +2798,19 @@ class RunInventoryUpdate(BaseTask):
from awx.main.management.commands.inventory_import import Command as InventoryImportCommand from awx.main.management.commands.inventory_import import Command as InventoryImportCommand
cmd = InventoryImportCommand() cmd = InventoryImportCommand()
exc = None
try: try:
# note that we are only using the management command to # save the inventory data to database.
# save the inventory data to the database. # canceling exceptions will be handled in the global post_run_hook
# we are not asking it to actually fetch hosts / groups. cmd.perform_update(options, data, inventory_update)
# that work was taken care of earlier, when except PermissionDenied as exc:
# BaseTask.run called ansible-inventory (by way of ansible-runner) logger.exception('License error saving {} content'.format(inventory_update.log_format))
# for us. raise PostRunError(str(exc), status='error')
save_status, tb, exc = cmd.perform_update(options, data, inventory_update) except Exception:
except Exception as raw_exc: logger.exception('Exception saving {} content, rolling back changes.'.format(
if exc is None: inventory_update.log_format))
exc = raw_exc raise PostRunError(
# Ignore license errors specifically 'Error occured while saving inventory data, see traceback or server logs',
if 'Host limit for organization' not in str(exc) and 'License' not in str(exc): status='error', tb=traceback.format_exc())
raise raw_exc
model_updates = {}
if save_status != status:
model_updates['status'] = save_status
if tb:
model_updates['result_traceback'] = tb
if model_updates:
logger.info('{} had problems saving to database with {}'.format(
inventory_update.log_format, ', '.join(list(model_updates.keys()))
))
model_updates['job_explanation'] = 'Update failed to save all changes to database properly.'
if exc:
model_updates['job_explanation'] += ' {}'.format(exc)
self.update_model(inventory_update.pk, **model_updates)
@task(queue=get_local_queuename) @task(queue=get_local_queuename)

View File

@@ -9,6 +9,9 @@ import os
# Django # Django
from django.core.management.base import CommandError from django.core.management.base import CommandError
# for license errors
from rest_framework.exceptions import PermissionDenied
# AWX # AWX
from awx.main.management.commands import inventory_import from awx.main.management.commands import inventory_import
from awx.main.models import Inventory, Host, Group, InventorySource from awx.main.models import Inventory, Host, Group, InventorySource
@@ -322,6 +325,6 @@ def test_tower_version_compare():
"version": "2.0.1-1068-g09684e2c41" "version": "2.0.1-1068-g09684e2c41"
} }
} }
with pytest.raises(CommandError): with pytest.raises(PermissionDenied):
cmd.remote_tower_license_compare('very_supported') cmd.remote_tower_license_compare('very_supported')
cmd.remote_tower_license_compare('open') cmd.remote_tower_license_compare('open')