Make ask_mapping a simple class property

from PR feedback of saved launchtime configurations
This commit is contained in:
AlanCoding
2017-12-06 17:08:55 -05:00
parent 98df442ced
commit 72a8854c27
16 changed files with 82 additions and 71 deletions

View File

@@ -3096,7 +3096,7 @@ class WorkflowJobTemplateNodeSerializer(LaunchConfigurationBaseSerializer):
def validate(self, attrs): def validate(self, attrs):
deprecated_fields = {} deprecated_fields = {}
if 'credential' in attrs: if 'credential' in attrs: # TODO: remove when v2 API is deprecated
deprecated_fields['credential'] = attrs.pop('credential') deprecated_fields['credential'] = attrs.pop('credential')
view = self.context.get('view') view = self.context.get('view')
if self.instance is None and ('workflow_job_template' not in attrs or if self.instance is None and ('workflow_job_template' not in attrs or
@@ -3120,7 +3120,7 @@ class WorkflowJobTemplateNodeSerializer(LaunchConfigurationBaseSerializer):
errors.pop('variables_needed_to_start', None) errors.pop('variables_needed_to_start', None)
if errors: if errors:
raise serializers.ValidationError(errors) raise serializers.ValidationError(errors)
if 'credential' in deprecated_fields: if 'credential' in deprecated_fields: # TODO: remove when v2 API is deprecated
cred = deprecated_fields['credential'] cred = deprecated_fields['credential']
attrs['credential'] = cred attrs['credential'] = cred
if cred is not None: if cred is not None:
@@ -3130,7 +3130,7 @@ class WorkflowJobTemplateNodeSerializer(LaunchConfigurationBaseSerializer):
raise PermissionDenied() raise PermissionDenied()
return attrs return attrs
def create(self, validated_data): def create(self, validated_data): # TODO: remove when v2 API is deprecated
deprecated_fields = {} deprecated_fields = {}
if 'credential' in validated_data: if 'credential' in validated_data:
deprecated_fields['credential'] = validated_data.pop('credential') deprecated_fields['credential'] = validated_data.pop('credential')
@@ -3140,7 +3140,7 @@ class WorkflowJobTemplateNodeSerializer(LaunchConfigurationBaseSerializer):
obj.credentials.add(deprecated_fields['credential']) obj.credentials.add(deprecated_fields['credential'])
return obj return obj
def update(self, obj, validated_data): def update(self, obj, validated_data): # TODO: remove when v2 API is deprecated
deprecated_fields = {} deprecated_fields = {}
if 'credential' in validated_data: if 'credential' in validated_data:
deprecated_fields['credential'] = validated_data.pop('credential') deprecated_fields['credential'] = validated_data.pop('credential')
@@ -3438,7 +3438,7 @@ class JobLaunchSerializer(BaseSerializer):
def get_defaults(self, obj): def get_defaults(self, obj):
defaults_dict = {} defaults_dict = {}
for field_name in JobTemplate.ask_mapping.keys(): for field_name in JobTemplate.get_ask_mapping().keys():
if field_name == 'inventory': if field_name == 'inventory':
defaults_dict[field_name] = dict( defaults_dict[field_name] = dict(
name=getattrd(obj, '%s.name' % field_name, None), name=getattrd(obj, '%s.name' % field_name, None),
@@ -3467,7 +3467,7 @@ class JobLaunchSerializer(BaseSerializer):
def validate(self, attrs): def validate(self, attrs):
template = self.context.get('template') template = self.context.get('template')
template._is_manual_launch = True # TODO: hopefully remove this template._is_manual_launch = True # signal to make several error types non-blocking
accepted, rejected, errors = template._accept_or_ignore_job_kwargs(**attrs) accepted, rejected, errors = template._accept_or_ignore_job_kwargs(**attrs)
self._ignored_fields = rejected self._ignored_fields = rejected
@@ -3493,13 +3493,12 @@ class JobLaunchSerializer(BaseSerializer):
passwords = attrs.get('credential_passwords', {}) # get from original attrs passwords = attrs.get('credential_passwords', {}) # get from original attrs
passwords_lacking = [] passwords_lacking = []
for cred in launch_credentials: for cred in launch_credentials:
if cred.passwords_needed: for p in cred.passwords_needed:
for p in cred.passwords_needed: if p not in passwords:
if p not in passwords: passwords_lacking.append(p)
passwords_lacking.append(p) else:
else: accepted.setdefault('credential_passwords', {})
accepted.setdefault('credential_passwords', {}) accepted['credential_passwords'][p] = passwords[p]
accepted['credential_passwords'][p] = passwords[p]
if len(passwords_lacking): if len(passwords_lacking):
errors['passwords_needed_to_start'] = passwords_lacking errors['passwords_needed_to_start'] = passwords_lacking

View File

@@ -616,22 +616,25 @@ class LaunchConfigCredentialsBase(SubListAttachDetachAPIView):
def is_valid_relation(self, parent, sub, created=False): def is_valid_relation(self, parent, sub, created=False):
if not parent.unified_job_template: if not parent.unified_job_template:
return {"msg": _("Cannot assign credential when related template is null.")} return {"msg": _("Cannot assign credential when related template is null.")}
elif self.relationship not in parent.unified_job_template.ask_mapping:
return {"msg": _("Related template cannot accept credentials on launch.")} ask_mapping = parent.unified_job_template.get_ask_mapping()
if self.relationship not in ask_mapping:
return {"msg": _("Related template cannot accept {} on launch.").format(self.relationship)}
elif sub.passwords_needed: elif sub.passwords_needed:
return {"msg": _("Credential that requires user input on launch " return {"msg": _("Credential that requires user input on launch "
"cannot be used in saved launch configuration.")} "cannot be used in saved launch configuration.")}
ask_field_name = parent.unified_job_template.ask_mapping[self.relationship] ask_field_name = ask_mapping[self.relationship]
if not getattr(parent, ask_field_name): if not getattr(parent, ask_field_name):
return {"msg": _("Related template is not configured to accept credentials on launch.")} return {"msg": _("Related template is not configured to accept credentials on launch.")}
elif sub.unique_hash() in [cred.unique_hash() for cred in parent.credentials.all()]: elif sub.unique_hash() in [cred.unique_hash() for cred in parent.credentials.all()]:
return {"msg": _("This launch configuration already provides a {credential_type} credential.".format( return {"msg": _("This launch configuration already provides a {credential_type} credential.").format(
credential_type=sub.unique_hash(display=True)))} credential_type=sub.unique_hash(display=True))}
elif sub.pk in parent.unified_job_template.credentials.values_list('pk', flat=True): elif sub.pk in parent.unified_job_template.credentials.values_list('pk', flat=True):
return {"msg": _("Related template already uses {credential_type} credential.".format( return {"msg": _("Related template already uses {credential_type} credential.").format(
credential_type=sub.name))} credential_type=sub.name)}
# None means there were no validation errors # None means there were no validation errors
return None return None
@@ -2752,7 +2755,7 @@ class JobTemplateLaunch(RetrieveAPIView):
extra_vars.setdefault(v, u'') extra_vars.setdefault(v, u'')
if extra_vars: if extra_vars:
data['extra_vars'] = extra_vars data['extra_vars'] = extra_vars
modified_ask_mapping = JobTemplate.ask_mapping.copy() modified_ask_mapping = JobTemplate.get_ask_mapping()
modified_ask_mapping.pop('extra_vars') modified_ask_mapping.pop('extra_vars')
for field, ask_field_name in modified_ask_mapping.items(): for field, ask_field_name in modified_ask_mapping.items():
if not getattr(obj, ask_field_name): if not getattr(obj, ask_field_name):
@@ -2823,7 +2826,7 @@ class JobTemplateLaunch(RetrieveAPIView):
# If user gave extra_credentials, special case to use exactly # If user gave extra_credentials, special case to use exactly
# the given list without merging with JT credentials # the given list without merging with JT credentials
if key == 'extra_credentials' and prompted_value: if key == 'extra_credentials' and prompted_value:
obj._deprecated_credential_launch = True obj._deprecated_credential_launch = True # signal to not merge credentials
new_credentials.extend(prompted_value) new_credentials.extend(prompted_value)
# combine the list of "new" and the filtered list of "old" # combine the list of "new" and the filtered list of "old"
@@ -2840,14 +2843,12 @@ class JobTemplateLaunch(RetrieveAPIView):
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
obj = self.get_object() obj = self.get_object()
print request.data
try: try:
modern_data, ignored_fields = self.modernize_launch_payload( modern_data, ignored_fields = self.modernize_launch_payload(
data=request.data, obj=obj data=request.data, obj=obj
) )
except ParseError as exc: except ParseError as exc:
print ' args ' + str(exc.args)
return Response(exc.detail, status=status.HTTP_400_BAD_REQUEST) return Response(exc.detail, status=status.HTTP_400_BAD_REQUEST)
serializer = self.serializer_class(data=modern_data, context={'template': obj}) serializer = self.serializer_class(data=modern_data, context={'template': obj})

View File

@@ -1403,7 +1403,7 @@ class JobAccess(BaseAccess):
except JobLaunchConfig.DoesNotExist: except JobLaunchConfig.DoesNotExist:
config = None config = None
# Check if JT execute access (and related prompts) are sufficient # Check if JT execute access (and related prompts) is sufficient
if obj.job_template is not None: if obj.job_template is not None:
if config is None: if config is None:
prompts_access = False prompts_access = False

View File

@@ -767,6 +767,16 @@ class AskForField(models.BooleanField):
""" """
Denotes whether to prompt on launch for another field on the same template Denotes whether to prompt on launch for another field on the same template
""" """
def __init__(self, allows_field='__default__', **kwargs): def __init__(self, allows_field=None, **kwargs):
super(AskForField, self).__init__(**kwargs) super(AskForField, self).__init__(**kwargs)
self.allows_field = allows_field self._allows_field = allows_field
@property
def allows_field(self):
if self._allows_field is None:
try:
return self.name[len('ask_'):-len('_on_launch')]
except AttributeError:
# self.name will be set by the model metaclass, not this field
raise Exception('Corresponding allows_field cannot be accessed until model is initialized.')
return self._allows_field

View File

@@ -7,6 +7,7 @@ import awx.main.fields
from awx.main.migrations import _migration_utils as migration_utils from awx.main.migrations import _migration_utils as migration_utils
from awx.main.migrations._multi_cred import migrate_workflow_cred, migrate_workflow_cred_reverse from awx.main.migrations._multi_cred import migrate_workflow_cred, migrate_workflow_cred_reverse
from awx.main.migrations._scan_jobs import remove_scan_type_nodes
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -69,6 +70,7 @@ class Migration(migrations.Migration):
# Run data migration before removing the old credential field # Run data migration before removing the old credential field
migrations.RunPython(migration_utils.set_current_apps_for_migrations, migrations.RunPython.noop), migrations.RunPython(migration_utils.set_current_apps_for_migrations, migrations.RunPython.noop),
migrations.RunPython(migrate_workflow_cred, migrate_workflow_cred_reverse), migrations.RunPython(migrate_workflow_cred, migrate_workflow_cred_reverse),
migrations.RunPython(remove_scan_type_nodes, migrations.RunPython.noop),
migrations.RemoveField( migrations.RemoveField(
model_name='workflowjobnode', model_name='workflowjobnode',
name='credential', name='credential',

View File

@@ -82,3 +82,21 @@ def _migrate_scan_job_templates(apps):
def migrate_scan_job_templates(apps, schema_editor): def migrate_scan_job_templates(apps, schema_editor):
_migrate_scan_job_templates(apps) _migrate_scan_job_templates(apps)
def remove_scan_type_nodes(apps, schema_editor):
WorkflowJobTemplateNode = apps.get_model('main', 'WorkflowJobTemplateNode')
WorkflowJobNode = apps.get_model('main', 'WorkflowJobNode')
for cls in (WorkflowJobNode, WorkflowJobTemplateNode):
for node in cls.objects.iterator():
prompts = node.char_prompts
if prompts.get('job_type', None) == 'scan':
log_text = '{} set job_type to scan, which was deprecated in 3.2, removing.'.format(cls)
if cls == WorkflowJobNode:
logger.info(log_text)
else:
logger.debug(log_text)
prompts.pop('job_type')
node.char_prompts = prompts
node.save()

View File

@@ -341,7 +341,7 @@ class JobTemplate(UnifiedJobTemplate, JobOptions, SurveyJobTemplateMixin, Resour
# that of job template launch, so prompting_needed should # that of job template launch, so prompting_needed should
# not block a provisioning callback from creating/launching jobs. # not block a provisioning callback from creating/launching jobs.
if callback_extra_vars is None: if callback_extra_vars is None:
for ask_field_name in set(self.ask_mapping.values()): for ask_field_name in set(self.get_ask_mapping().values()):
if getattr(self, ask_field_name): if getattr(self, ask_field_name):
prompting_needed = True prompting_needed = True
break break
@@ -359,7 +359,7 @@ class JobTemplate(UnifiedJobTemplate, JobOptions, SurveyJobTemplateMixin, Resour
rejected_data['extra_vars'] = rejected_vars rejected_data['extra_vars'] = rejected_vars
# Handle all the other fields that follow the simple prompting rule # Handle all the other fields that follow the simple prompting rule
for field_name, ask_field_name in self.ask_mapping.items(): for field_name, ask_field_name in self.get_ask_mapping().items():
if field_name not in kwargs or field_name == 'extra_vars' or kwargs[field_name] is None: if field_name not in kwargs or field_name == 'extra_vars' or kwargs[field_name] is None:
continue continue
@@ -370,7 +370,7 @@ class JobTemplate(UnifiedJobTemplate, JobOptions, SurveyJobTemplateMixin, Resour
if isinstance(field, models.ManyToManyField): if isinstance(field, models.ManyToManyField):
old_value = set(old_value.all()) old_value = set(old_value.all())
if getattr(self, '_deprecated_credential_launch', False): if getattr(self, '_deprecated_credential_launch', False):
# pass # TODO: remove this code branch when support for `extra_credentials` goes away
new_value = set(kwargs[field_name]) new_value = set(kwargs[field_name])
else: else:
new_value = set(kwargs[field_name]) - old_value new_value = set(kwargs[field_name]) - old_value
@@ -859,7 +859,7 @@ class LaunchTimeConfig(BaseModel):
def prompts_dict(self, display=False): def prompts_dict(self, display=False):
data = {} data = {}
for prompt_name in JobTemplate.ask_mapping.keys(): for prompt_name in JobTemplate.get_ask_mapping().keys():
try: try:
field = self._meta.get_field(prompt_name) field = self._meta.get_field(prompt_name)
except FieldDoesNotExist: except FieldDoesNotExist:
@@ -919,7 +919,7 @@ class LaunchTimeConfig(BaseModel):
return None return None
for field_name in JobTemplate.ask_mapping.keys(): for field_name in JobTemplate.get_ask_mapping().keys():
try: try:
LaunchTimeConfig._meta.get_field(field_name) LaunchTimeConfig._meta.get_field(field_name)
except FieldDoesNotExist: except FieldDoesNotExist:
@@ -948,7 +948,7 @@ class JobLaunchConfig(LaunchTimeConfig):
launching with those prompts launching with those prompts
''' '''
prompts = self.prompts_dict() prompts = self.prompts_dict()
for field_name, ask_field_name in template.ask_mapping.items(): for field_name, ask_field_name in template.get_ask_mapping().items():
if field_name in prompts and not getattr(template, ask_field_name): if field_name in prompts and not getattr(template, ask_field_name):
return True return True
else: else:

View File

@@ -36,8 +36,7 @@ from awx.main.models.mixins import ResourceMixin, TaskManagerUnifiedJobMixin
from awx.main.utils import ( from awx.main.utils import (
decrypt_field, _inventory_updates, decrypt_field, _inventory_updates,
copy_model_by_class, copy_m2m_relationships, copy_model_by_class, copy_m2m_relationships,
get_type_for_model, parse_yaml_or_json, get_type_for_model, parse_yaml_or_json
cached_subclassproperty
) )
from awx.main.redact import UriCleaner, REPLACE_STR from awx.main.redact import UriCleaner, REPLACE_STR
from awx.main.consumers import emit_channel_notification from awx.main.consumers import emit_channel_notification
@@ -395,17 +394,16 @@ class UnifiedJobTemplate(PolymorphicModel, CommonModelNameNotUnique, Notificatio
return unified_job return unified_job
@cached_subclassproperty @classmethod
def ask_mapping(cls): def get_ask_mapping(cls):
'''
Creates dictionary that maps the unified job field (keys)
to the field that enables prompting for the field (values)
'''
mapping = {} mapping = {}
for field in cls._meta.fields: for field in cls._meta.fields:
if not isinstance(field, AskForField): if isinstance(field, AskForField):
continue mapping[field.allows_field] = field.name
if field.allows_field == '__default__':
allows_field = field.name[len('ask_'):-len('_on_launch')]
else:
allows_field = field.allows_field
mapping[allows_field] = field.name
return mapping return mapping
@classmethod @classmethod
@@ -862,7 +860,7 @@ class UnifiedJob(PolymorphicModel, PasswordFieldsModel, CommonModelNameNotUnique
JobLaunchConfig = self._meta.get_field('launch_config').related_model JobLaunchConfig = self._meta.get_field('launch_config').related_model
config = JobLaunchConfig(job=self) config = JobLaunchConfig(job=self)
for field_name, value in kwargs.items(): for field_name, value in kwargs.items():
if (field_name not in self.unified_job_template.ask_mapping and field_name != 'survey_passwords'): if (field_name not in self.unified_job_template.get_ask_mapping() and field_name != 'survey_passwords'):
raise Exception('Unrecognized launch config field {}.'.format(field_name)) raise Exception('Unrecognized launch config field {}.'.format(field_name))
if field_name == 'credentials': if field_name == 'credentials':
continue continue

View File

@@ -26,6 +26,7 @@ from awx.main.models.rbac import (
from awx.main.fields import ImplicitRoleField from awx.main.fields import ImplicitRoleField
from awx.main.models.mixins import ResourceMixin, SurveyJobTemplateMixin, SurveyJobMixin from awx.main.models.mixins import ResourceMixin, SurveyJobTemplateMixin, SurveyJobMixin
from awx.main.models.jobs import LaunchTimeConfig from awx.main.models.jobs import LaunchTimeConfig
from awx.main.models.credential import Credential
from awx.main.redact import REPLACE_STR from awx.main.redact import REPLACE_STR
from awx.main.fields import JSONField from awx.main.fields import JSONField
@@ -130,7 +131,6 @@ class WorkflowJobTemplateNode(WorkflowNodeBase):
allowed_creds = [] allowed_creds = []
for field_name in self._get_workflow_job_field_names(): for field_name in self._get_workflow_job_field_names():
if field_name == 'credentials': if field_name == 'credentials':
Credential = self._meta.get_field('credentials').related_model
for cred in self.credentials.all(): for cred in self.credentials.all():
if user.can_access(Credential, 'use', cred): if user.can_access(Credential, 'use', cred):
allowed_creds.append(cred) allowed_creds.append(cred)

View File

@@ -91,9 +91,6 @@ class TestWorkflowJobTemplateNodeSerializerGetRelated():
'always_nodes', 'always_nodes',
]) ])
def test_get_related(self, test_get_related, workflow_job_template_node, related_resource_name): def test_get_related(self, test_get_related, workflow_job_template_node, related_resource_name):
serializer = WorkflowJobTemplateNodeSerializer()
print serializer.get_related(workflow_job_template_node)
# import pdb; pdb.set_trace()
test_get_related(WorkflowJobTemplateNodeSerializer, test_get_related(WorkflowJobTemplateNodeSerializer,
workflow_job_template_node, workflow_job_template_node,
'workflow_job_template_nodes', 'workflow_job_template_nodes',

View File

@@ -104,5 +104,5 @@ def test_job_template_can_start_with_callback_extra_vars_provided(job_template_f
def test_ask_mapping_integrity(): def test_ask_mapping_integrity():
assert 'credentials' in JobTemplate.ask_mapping assert 'credentials' in JobTemplate.get_ask_mapping()
assert JobTemplate.ask_mapping['job_tags'] == 'ask_tags_on_launch' assert JobTemplate.get_ask_mapping()['job_tags'] == 'ask_tags_on_launch'

View File

@@ -13,7 +13,7 @@ from awx.main.models import (
@pytest.mark.survey @pytest.mark.survey
class SurveyVariableValidation: class SurveyVariableValidation:
def test_survey_answers_as_string(self, job_template_factory): def test_survey_answers_as_string(self, job_template_factory):
objects = job_template_factory( objects = job_template_factory(
'job-template-with-survey', 'job-template-with-survey',

View File

@@ -1,7 +1,7 @@
import pytest import pytest
from awx.main.models import SystemJobTemplate from awx.main.models import SystemJobTemplate
@pytest.mark.parametrize("extra_data", [ @pytest.mark.parametrize("extra_data", [
'{ "days": 1 }', '{ "days": 1 }',

View File

@@ -233,5 +233,5 @@ class TestWorkflowJobNodeJobKWARGS:
assert job_node_no_prompts.get_job_kwargs() == self.kwargs_base assert job_node_no_prompts.get_job_kwargs() == self.kwargs_base
def test_ask_mapping_integrity(): def test_get_ask_mapping_integrity():
assert WorkflowJobTemplate.ask_mapping.keys() == ['extra_vars'] assert WorkflowJobTemplate.get_ask_mapping().keys() == ['extra_vars']

View File

@@ -47,7 +47,7 @@ __all__ = ['get_object_or_400', 'get_object_or_403', 'camelcase_to_underscore',
'extract_ansible_vars', 'get_search_fields', 'get_system_task_capacity', 'extract_ansible_vars', 'get_search_fields', 'get_system_task_capacity',
'wrap_args_with_proot', 'build_proot_temp_dir', 'check_proot_installed', 'model_to_dict', 'wrap_args_with_proot', 'build_proot_temp_dir', 'check_proot_installed', 'model_to_dict',
'model_instance_diff', 'timestamp_apiformat', 'parse_yaml_or_json', 'RequireDebugTrueOrTest', 'model_instance_diff', 'timestamp_apiformat', 'parse_yaml_or_json', 'RequireDebugTrueOrTest',
'has_model_field_prefetched', 'set_environ', 'IllegalArgumentError', 'cached_subclassproperty',] 'has_model_field_prefetched', 'set_environ', 'IllegalArgumentError',]
def get_object_or_400(klass, *args, **kwargs): def get_object_or_400(klass, *args, **kwargs):
@@ -935,17 +935,3 @@ def has_model_field_prefetched(model_obj, field_name):
# NOTE: Update this function if django internal implementation changes. # NOTE: Update this function if django internal implementation changes.
return getattr(getattr(model_obj, field_name, None), return getattr(getattr(model_obj, field_name, None),
'prefetch_cache_name', '') in getattr(model_obj, '_prefetched_objects_cache', {}) 'prefetch_cache_name', '') in getattr(model_obj, '_prefetched_objects_cache', {})
class cached_subclassproperty(object):
'''Caches property in subclasses'''
def __init__(self, method):
self.method = method
self.name = method.__name__
def __get__(self, instance, cls):
r = self.method(cls)
if self.name not in cls.__dict__:
setattr(cls, self.name, r)
return r

View File

@@ -45,7 +45,7 @@ extra_vars.
Prompting enablement for several types of credentials is controlled by a single Prompting enablement for several types of credentials is controlled by a single
field. On launch, multiple types of credentials can be provided in their respective fields field. On launch, multiple types of credentials can be provided in their respective fields
inside of `credential`, `vault_credential`, and `extra_credentials`. Providing inside of `credential`, `vault_credential`, and `extra_credentials`. Providing
a credential that requirements password input from the user on launch is credentials that require password input from the user on launch is
allowed, and the password must be provided along-side the credential, of course. allowed, and the password must be provided along-side the credential, of course.
If the job is being spawned using a saved launch configuration, however, If the job is being spawned using a saved launch configuration, however,