Files
awx/awx/main/tests/unit/utils/test_common.py
2024-03-05 11:38:14 +00:00

340 lines
14 KiB
Python

# -*- coding: utf-8 -*-
# Copyright (c) 2017 Ansible, Inc.
# All Rights Reserved.
import os
import re
import pytest
from uuid import uuid4
import json
import yaml
from unittest import mock
from rest_framework.exceptions import ParseError
from ansible_base.lib.utils.models import get_type_for_model
from awx.main.utils import common
from awx.api.validators import HostnameRegexValidator
from awx.main.models import Job, AdHocCommand, InventoryUpdate, ProjectUpdate, SystemJob, WorkflowJob, Inventory, JobTemplate, UnifiedJobTemplate, UnifiedJob
from django.core.exceptions import ValidationError
from django.utils.regex_helper import _lazy_re_compile
@pytest.mark.parametrize(
    'input_, output',
    [
        ({"foo": "bar"}, {"foo": "bar"}),
        ('{"foo": "bar"}', {"foo": "bar"}),
        ('---\nfoo: bar', {"foo": "bar"}),
        (4399, {}),
    ],
)
def test_parse_yaml_or_json(input_, output):
    """Check parse_yaml_or_json against a dict, a JSON string, a YAML string and an integer."""
    parsed = common.parse_yaml_or_json(input_)
    assert parsed == output
def test_recursive_vars_not_allowed():
    """A self-referencing structure dumped to YAML must be rejected with a ParseError.

    The error text is read from the raised exception itself (``exc.value``),
    matching the other parser tests in this module, rather than from the
    pytest ``ExceptionInfo`` wrapper whose string form is pytest-version
    dependent.
    """
    rdict = {}
    rdict['a'] = rdict
    # YAML dumper will use a tag to give recursive data
    data = yaml.dump(rdict, default_flow_style=False)
    with pytest.raises(ParseError) as exc:
        common.parse_yaml_or_json(data, silent_failure=False)
    assert 'Circular reference detected' in str(exc.value)
class TestParserExceptions:
    """Checks on the ParseError messages produced by parse_yaml_or_json."""

    @staticmethod
    def json_error(data):
        """Return the message raised by ``json.loads(data)``, or None if it parses."""
        try:
            json.loads(data)
        except Exception as err:
            return str(err)
        return None

    @staticmethod
    def yaml_error(data):
        """Return the message raised by ``yaml.safe_load(data)``, or None if it parses."""
        try:
            yaml.safe_load(data)
        except Exception as err:
            return str(err)
        return None

    def test_invalid_JSON_and_YAML(self):
        """The combined error must embed both the JSON and the YAML parser messages."""
        data = "{key:val"
        with pytest.raises(ParseError) as exc:
            common.parse_yaml_or_json(data, silent_failure=False)
        message = str(exc.value)
        expected_fragments = (
            "Cannot parse as",
            self.json_error(data),
            self.yaml_error(data),
        )
        for fragment in expected_fragments:
            assert fragment in message

    def test_invalid_vars_type(self):
        """Input that parses to a list must be reported as not being a dictionary."""
        data = "[1, 2, 3]"
        with pytest.raises(ParseError) as exc:
            common.parse_yaml_or_json(data, silent_failure=False)
        message = str(exc.value)
        for fragment in ("Cannot parse as", "Input type `list` is not a dictionary"):
            assert fragment in message
def test_set_environ():
    """set_environ adds exactly the given key inside the block and restores os.environ after."""
    key = str(uuid4())
    environ_before = os.environ.copy()
    with common.set_environ(**{key: 'bar'}):
        assert os.environ[key] == 'bar'
        # Only the freshly generated key may have been added.
        added_keys = set(os.environ.keys()) - set(environ_before.keys())
        assert added_keys == {key}
    assert os.environ == environ_before
    assert key not in os.environ
# (model class, expected type-name string) pairs shared by the parametrized
# model/type tests in this module.
TEST_MODELS = [
    (Job, 'job'),
    (AdHocCommand, 'ad_hoc_command'),
    (InventoryUpdate, 'inventory_update'),
    (ProjectUpdate, 'project_update'),
    (SystemJob, 'system_job'),
    (WorkflowJob, 'workflow_job'),
    (UnifiedJob, 'unified_job'),
    (Inventory, 'inventory'),
    (JobTemplate, 'job_template'),
    (UnifiedJobTemplate, 'unified_job_template'),
]
# Cases relied on for scheduler dependent jobs list
@pytest.mark.parametrize('model,name', TEST_MODELS)
def test_get_type_for_model(model, name):
    """get_type_for_model must return the expected type name for each model class."""
    resolved_name = get_type_for_model(model)
    assert resolved_name == name
def test_get_model_for_invalid_type():
    """An unknown type name must make get_model_for_type raise LookupError."""
    unknown_type = 'foobar'
    with pytest.raises(LookupError):
        common.get_model_for_type(unknown_type)
@pytest.mark.parametrize("model_type,model_class", [(name, cls) for cls, name in TEST_MODELS])
def test_get_model_for_valid_type(model_type, model_class):
    """get_model_for_type must map each known type name back to its model class."""
    resolved_class = common.get_model_for_type(model_type)
    assert resolved_class == model_class
def test_is_testing():
    """is_testing() must return exactly True when called from this test run."""
    result = common.is_testing()
    assert result is True
@pytest.mark.parametrize("model_type,model_class", [(name, cls) for cls, name in TEST_MODELS])
def test_get_capacity_type(model_type, model_class):
    """Check get_capacity_type for every model in TEST_MODELS.

    Three type names are expected to raise RuntimeError; every other type is
    expected to yield 'execution', 'control' or None, both for the model
    class and for a default-constructed instance of it.
    """
    if model_type in ('unified_job', 'unified_job_template', 'inventory'):
        with pytest.raises(RuntimeError):
            common.get_capacity_type(model_class)
        return

    capacity_by_type = {
        'job': 'execution',
        'ad_hoc_command': 'execution',
        'inventory_update': 'execution',
        'job_template': 'execution',
        'project_update': 'control',
        'system_job': 'control',
    }
    # Types absent from the table are expected to report no capacity type.
    expectation = capacity_by_type.get(model_type)
    assert common.get_capacity_type(model_class) == expectation
    assert common.get_capacity_type(model_class()) == expectation
@pytest.fixture
def memoized_function(mocker, mock_cache):
    """Provide a function wrapped by ``common.memoize(track_function=True)``.

    ``get_memoize_cache`` is patched to return ``mock_cache`` while the
    decorator is applied. The returned function keeps a per-key call counter
    in its ``calls`` attribute and changes its return value after the first
    real invocation for a key, so tests can tell cached results from fresh ones.
    """
    with mock.patch('awx.main.utils.common.get_memoize_cache', return_value=mock_cache):

        @common.memoize(track_function=True)
        def myfunction(key, value):
            call_count = myfunction.calls.get(key, 0) + 1
            myfunction.calls[key] = call_count
            if call_count != 1:
                return '%s called %s times' % (value, call_count)
            return value

        myfunction.calls = {}
        return myfunction
def test_memoize_track_function(memoized_function, mock_cache):
    """Results are stored under 'myfunction' and repeat calls do not re-run the function."""
    cases = (
        ('scott', 'scotterson', u'scott-scotterson'),
        ('john', 'smith', u'john-smith'),
    )
    cached_so_far = {}
    for key, value, cache_slug in cases:
        assert memoized_function(key, value) == value
        # The tracked cache entry accumulates one slot per distinct argument set.
        cached_so_far[cache_slug] = value
        assert mock_cache.get('myfunction') == cached_so_far
        assert memoized_function(key, value) == value
        assert memoized_function.calls[key] == 1
def test_memoize_delete(memoized_function, mock_cache):
    """After the cache entry is deleted, the wrapped function must execute again."""
    for _ in range(2):
        assert memoized_function('john', 'smith') == 'smith'
    assert memoized_function.calls['john'] == 1
    assert mock_cache.get('myfunction') == {u'john-smith': 'smith'}

    # Route memoize_delete to the mock cache's delete for the duration of the call.
    delete_patch = mock.patch('awx.main.utils.common.memoize_delete', side_effect=mock_cache.delete)
    with delete_patch:
        common.memoize_delete('myfunction')
    assert mock_cache.get('myfunction') is None

    second_result = memoized_function('john', 'smith')
    assert second_result == 'smith called 2 times'
    assert memoized_function.calls['john'] == 2
def test_memoize_parameter_error():
    """Decorating with both cache_key and track_function=True must raise IllegalArgumentError."""

    def fn():
        return None

    with pytest.raises(common.IllegalArgumentError):
        # Explicit form of applying @common.memoize(...) to fn.
        common.memoize(cache_key='foo', track_function=True)(fn)
def test_extract_ansible_vars():
    """extract_ansible_vars must report the 'ansible_connetion_setting' key and drop it from the result."""
    payload = json.dumps({"foobar": "baz", "ansible_connetion_setting": "1928"})
    redacted, var_list = common.extract_ansible_vars(payload)
    assert var_list == {'ansible_connetion_setting'}
    assert redacted == {"foobar": "baz"}
@pytest.mark.parametrize(
    'scm_type, url, username, password, check_special_cases, scp_format, expected',
    [
        # General/random cases
        ('git', '', True, True, True, False, ''),
        ('git', 'git://example.com/foo.git', True, True, True, False, 'git://example.com/foo.git'),
        ('git', 'http://example.com/foo.git', True, True, True, False, 'http://example.com/foo.git'),
        ('git', 'example.com:bar.git', True, True, True, False, 'git+ssh://example.com/bar.git'),
        ('git', 'user@example.com:bar.git', True, True, True, False, 'git+ssh://user@example.com/bar.git'),
        ('git', '127.0.0.1:bar.git', True, True, True, False, 'git+ssh://127.0.0.1/bar.git'),
        ('git', 'git+ssh://127.0.0.1/bar.git', True, True, True, True, '127.0.0.1:bar.git'),
        ('git', 'ssh://127.0.0.1:22/bar.git', True, True, True, False, 'ssh://127.0.0.1:22/bar.git'),
        ('git', 'ssh://root@127.0.0.1:22/bar.git', True, True, True, False, 'ssh://root@127.0.0.1:22/bar.git'),
        ('git', 'some/path', True, True, True, False, 'file:///some/path'),
        ('git', '/some/path', True, True, True, False, 'file:///some/path'),
        # Invalid URLs - ensure we error properly
        ('cvs', 'anything', True, True, True, False, ValueError('Unsupported SCM type "cvs"')),
        ('svn', 'anything-without-colon-slash-slash', True, True, True, False, ValueError('Invalid svn URL')),
        ('git', 'http://example.com:123invalidport/foo.git', True, True, True, False, ValueError('Invalid git URL')),
        ('git', 'git+ssh://127.0.0.1/bar.git', True, True, True, False, ValueError('Unsupported git URL')),
        ('git', 'git@example.com:3000:/git/repo.git', True, True, True, False, ValueError('Invalid git URL')),
        ('insights', 'git://example.com/foo.git', True, True, True, False, ValueError('Unsupported insights URL')),
        ('svn', 'file://example/path', True, True, True, False, ValueError('Unsupported host "example" for file:// URL')),
        ('svn', 'svn:///example', True, True, True, False, ValueError('Host is required for svn URL')),
        # Username/password cases
        ('git', 'https://example@example.com/bar.git', False, True, True, False, 'https://example.com/bar.git'),
        ('git', 'https://example@example.com/bar.git', 'user', True, True, False, 'https://user@example.com/bar.git'),
        ('git', 'https://example@example.com/bar.git', 'user:pw', True, True, False, 'https://user%3Apw@example.com/bar.git'),
        ('git', 'https://example@example.com/bar.git', False, 'pw', True, False, 'https://example.com/bar.git'),
        ('git', 'https://some:example@example.com/bar.git', True, False, True, False, 'https://some@example.com/bar.git'),
        ('git', 'https://some:example@example.com/bar.git', False, False, True, False, 'https://example.com/bar.git'),
        ('git', 'https://example.com/bar.git', 'user', 'pw', True, False, 'https://user:pw@example.com/bar.git'),
        ('git', 'https://example@example.com/bar.git', False, 'something', True, False, 'https://example.com/bar.git'),
        # Special github/bitbucket cases
        ('git', 'notgit@github.com:ansible/awx.git', True, True, True, False, ValueError('Username must be "git" for SSH access to github.com.')),
        (
            'git',
            'notgit@bitbucket.org:does-not-exist/example.git',
            True,
            True,
            True,
            False,
            ValueError('Username must be "git" for SSH access to bitbucket.org.'),
        ),
        (
            'git',
            'notgit@altssh.bitbucket.org:does-not-exist/example.git',
            True,
            True,
            True,
            False,
            ValueError('Username must be "git" for SSH access to altssh.bitbucket.org.'),
        ),
        ('git', 'git:password@github.com:ansible/awx.git', True, True, True, False, 'git+ssh://git@github.com/ansible/awx.git'),
        # Disabling the special handling should not raise an error
        ('git', 'notgit@github.com:ansible/awx.git', True, True, False, False, 'git+ssh://notgit@github.com/ansible/awx.git'),
        ('git', 'notgit@bitbucket.org:does-not-exist/example.git', True, True, False, False, 'git+ssh://notgit@bitbucket.org/does-not-exist/example.git'),
        (
            'git',
            'notgit@altssh.bitbucket.org:does-not-exist/example.git',
            True,
            True,
            False,
            False,
            'git+ssh://notgit@altssh.bitbucket.org/does-not-exist/example.git',
        ),
        # awx#12992 - IPv6
        ('git', 'http://[fd00:1234:2345:6789::11]:3000/foo.git', True, True, True, False, 'http://[fd00:1234:2345:6789::11]:3000/foo.git'),
        ('git', 'http://foo:bar@[fd00:1234:2345:6789::11]:3000/foo.git', True, True, True, False, 'http://foo:bar@[fd00:1234:2345:6789::11]:3000/foo.git'),
        ('git', 'example@[fd00:1234:2345:6789::11]:example/foo.git', True, True, True, False, 'git+ssh://example@[fd00:1234:2345:6789::11]/example/foo.git'),
    ],
)
def test_update_scm_url(scm_type, url, username, password, check_special_cases, scp_format, expected):
    """Check update_scm_url against each parametrized case.

    When ``expected`` is an exception instance, the call must raise that
    exception type with an identical message; otherwise the call must return
    ``expected``.
    """
    call_args = (scm_type, url, username, password, check_special_cases, scp_format)

    if not isinstance(expected, Exception):
        assert common.update_scm_url(*call_args) == expected
        return

    with pytest.raises(type(expected)) as excinfo:
        common.update_scm_url(*call_args)
    assert str(excinfo.value) == str(expected)
class TestHostnameRegexValidator:
    """Tests for constructing and calling HostnameRegexValidator."""

    @pytest.fixture
    def regex_expr(self):
        """Hostname pattern used by these tests; the default-constructor test expects the validator's default to equal it."""
        return '^[a-z0-9][-a-z0-9]*$|^([a-z0-9][-a-z0-9]{0,62}[.])*[a-z0-9][-a-z0-9]{1,62}$'

    @pytest.fixture
    def re_flags(self):
        """Regex flags passed to (and expected on) the validator."""
        return re.IGNORECASE

    @pytest.fixture
    def custom_err_message(self):
        """Custom error message passed to the validator constructor."""
        return "foobar"

    def test_hostame_regex_validator_constructor_with_args(self, regex_expr, re_flags, custom_err_message):
        """Explicit constructor arguments must be reflected on the validator's attributes."""
        h = HostnameRegexValidator(regex=regex_expr, flags=re_flags, message=custom_err_message)
        assert h.regex == _lazy_re_compile(regex_expr, re_flags)
        assert h.message == 'foobar'
        assert h.code == 'invalid'
        assert h.inverse_match is False
        assert h.flags == re_flags

    def test_hostame_regex_validator_default_constructor(self, regex_expr, re_flags):
        """A validator built with no arguments must expose the expected defaults."""
        h = HostnameRegexValidator()
        assert h.regex == _lazy_re_compile(regex_expr, re_flags)
        assert h.message == 'Enter a valid value.'
        assert h.code == 'invalid'
        assert h.inverse_match is False
        assert h.flags == re_flags

    def test_good_call(self, regex_expr, re_flags):
        """A valid address must pass validation: the call returns None without raising."""
        h = HostnameRegexValidator(regex=regex_expr, flags=re_flags)
        # Assert on the return value itself; asserting a (value, None) tuple is always true.
        assert h("192.168.56.101") is None

    def test_bad_call(self, regex_expr, re_flags):
        """An invalid hostname must raise ValidationError carrying a message."""
        h = HostnameRegexValidator(regex=regex_expr, flags=re_flags)
        # pytest.raises fails the test if no exception is raised, unlike a bare try/except.
        with pytest.raises(ValidationError) as excinfo:
            h("@#$%)$#(TUFAS_DG")
        assert excinfo.value.message is not None

    def test_good_call_with_inverse(self, regex_expr, re_flags, inverse_match=True):
        """With inverse_match enabled, an otherwise valid address must raise ValidationError."""
        h = HostnameRegexValidator(regex=regex_expr, flags=re_flags, inverse_match=inverse_match)
        with pytest.raises(ValidationError) as excinfo:
            h("1.2.3.4")
        assert excinfo.value.message is not None

    def test_bad_call_with_inverse(self, regex_expr, re_flags, inverse_match=True):
        """With inverse_match enabled, an otherwise invalid hostname must pass and return None."""
        h = HostnameRegexValidator(regex=regex_expr, flags=re_flags, inverse_match=inverse_match)
        assert h("@#$%)$#(TUFAS_DG") is None