Validate against ansible variables on ad hoc launch

Share code between this check for ad hoc and JT callback
This commit is contained in:
AlanCoding 2017-10-05 12:14:05 -04:00
parent 02e3f45422
commit eacbeef660
No known key found for this signature in database
GPG Key ID: FD2C3C012A72926B
5 changed files with 40 additions and 13 deletions

View File

@ -45,7 +45,7 @@ from awx.main.fields import ImplicitRoleField
from awx.main.utils import (
get_type_for_model, get_model_for_type, timestamp_apiformat,
camelcase_to_underscore, getattrd, parse_yaml_or_json,
has_model_field_prefetched)
has_model_field_prefetched, extract_ansible_vars)
from awx.main.utils.filters import SmartFilter
from awx.main.validators import vars_validate_or_raise
@ -2759,6 +2759,14 @@ class AdHocCommandSerializer(UnifiedJobSerializer):
ret['name'] = obj.module_name
return ret
def validate_extra_vars(self, value):
redacted_extra_vars, removed_vars = extract_ansible_vars(value)
if removed_vars:
raise serializers.ValidationError(_(
"Variables {} are prohibited from use in ad hoc commands."
).format(",".join(removed_vars)))
return vars_validate_or_raise(value)
class AdHocCommandCancelSerializer(AdHocCommandSerializer):

View File

@ -69,7 +69,7 @@ from awx.conf.license import get_license, feature_enabled, feature_exists, Licen
from awx.main.models import * # noqa
from awx.main.utils import * # noqa
from awx.main.utils import (
callback_filter_out_ansible_extra_vars,
extract_ansible_vars,
decrypt_field,
)
from awx.main.utils.filters import SmartFilter
@ -3160,7 +3160,8 @@ class JobTemplateCallback(GenericAPIView):
# Everything is fine; actually create the job.
kv = {"limit": limit, "launch_type": 'callback'}
if extra_vars is not None and job_template.ask_variables_on_launch:
kv['extra_vars'] = callback_filter_out_ansible_extra_vars(extra_vars)
extra_vars_redacted, removed = extract_ansible_vars(extra_vars)
kv['extra_vars'] = extra_vars_redacted
with transaction.atomic():
job = job_template.create_job(**kv)

View File

@ -56,7 +56,7 @@ from awx.main.utils import (get_ansible_version, get_ssh_version, decrypt_field,
check_proot_installed, build_proot_temp_dir, get_licenser,
wrap_args_with_proot, get_system_task_capacity, OutputEventFilter,
parse_yaml_or_json, ignore_inventory_computed_fields, ignore_inventory_group_removal,
get_type_for_model)
get_type_for_model, extract_ansible_vars)
from awx.main.utils.reload import restart_local_services, stop_local_services
from awx.main.utils.handlers import configure_external_logger
from awx.main.consumers import emit_channel_notification
@ -2139,8 +2139,11 @@ class RunAdHocCommand(BaseTask):
args.append('-%s' % ('v' * min(5, ad_hoc_command.verbosity)))
if ad_hoc_command.extra_vars_dict:
if ad_hoc_command.extra_vars_dict.get('ansible_connection') == 'local':
raise ValueError(_("unable to use the `local` connection plugin with ad hoc commands"))
redacted_extra_vars, removed_vars = extract_ansible_vars(ad_hoc_command.extra_vars_dict)
if removed_vars:
raise ValueError(_(
"unable to use {} variables with ad hoc commands"
).format(",".format(removed_vars)))
args.extend(['-e', json.dumps(ad_hoc_command.extra_vars_dict)])

View File

@ -5,6 +5,7 @@
import os
import pytest
from uuid import uuid4
import json
from django.core.cache import cache
@ -115,3 +116,12 @@ def test_memoize_parameter_error():
with pytest.raises(common.IllegalArgumentError):
fn()
def test_extract_ansible_vars():
my_dict = {
"foobar": "baz",
"ansible_connetion_setting": "1928"
}
redacted, var_list = common.extract_ansible_vars(json.dumps(my_dict))
assert var_list == set(['ansible_connetion_setting'])
assert redacted == {"foobar": "baz"}

View File

@ -41,7 +41,7 @@ __all__ = ['get_object_or_400', 'get_object_or_403', 'camelcase_to_underscore',
'ignore_inventory_computed_fields', 'ignore_inventory_group_removal',
'_inventory_updates', 'get_pk_from_dict', 'getattrd', 'NoDefaultProvided',
'get_current_apps', 'set_current_apps', 'OutputEventFilter',
'callback_filter_out_ansible_extra_vars', 'get_search_fields', 'get_system_task_capacity',
'extract_ansible_vars', 'get_search_fields', 'get_system_task_capacity',
'wrap_args_with_proot', 'build_proot_temp_dir', 'check_proot_installed', 'model_to_dict',
'model_instance_diff', 'timestamp_apiformat', 'parse_yaml_or_json', 'RequireDebugTrueOrTest',
'has_model_field_prefetched', 'set_environ', 'IllegalArgumentError',]
@ -877,13 +877,18 @@ class OutputEventFilter(object):
self._current_event_data = None
def callback_filter_out_ansible_extra_vars(extra_vars):
extra_vars_redacted = {}
def is_ansible_variable(key):
return key.startswith('ansible_')
def extract_ansible_vars(extra_vars):
extra_vars = parse_yaml_or_json(extra_vars)
for key, value in extra_vars.iteritems():
if not key.startswith('ansible_'):
extra_vars_redacted[key] = value
return extra_vars_redacted
ansible_vars = set([])
for key in extra_vars.keys():
if is_ansible_variable(key):
extra_vars.pop(key)
ansible_vars.add(key)
return (extra_vars, ansible_vars)
def get_search_fields(model):