Set priority rules for find_matching_hosts.

This commit is contained in:
Aaron Tan
2017-06-14 15:45:15 -04:00
parent 422950f45d
commit 0dae058bef
3 changed files with 60 additions and 31 deletions

View File

@@ -3046,40 +3046,29 @@ class JobTemplateCallback(GenericAPIView):
# Find the host objects to search for a match. # Find the host objects to search for a match.
obj = self.get_object() obj = self.get_object()
hosts = obj.inventory.hosts.all() hosts = obj.inventory.hosts.all()
# First try for an exact match on the name. # Populate host_mappings
try: host_mappings = {}
return set([hosts.get(name__in=remote_hosts)])
except (Host.DoesNotExist, Host.MultipleObjectsReturned):
pass
# Next, try matching based on name or ansible_host variables.
matches = set()
for host in hosts: for host in hosts:
for host_var in ['ansible_ssh_host', 'ansible_host']: host_name = host.get_effective_host_name()
ansible_host = host.variables_dict.get(host_var, '') host_mappings.setdefault(host_name, [])
if ansible_host in remote_hosts: host_mappings[host_name].append(host)
matches.add(host) # Try finding direct match
if host.name != ansible_host and host.name in remote_hosts: matches = set()
matches.add(host) for host_name in remote_hosts:
if host_name in host_mappings:
matches.update(host_mappings[host_name])
if len(matches) == 1: if len(matches) == 1:
return matches return matches
# Try to resolve forward addresses for each host to find matches. # Try to resolve forward addresses for each host to find matches.
for host in hosts: for host_name in host_mappings:
hostnames = set([host.name]) try:
for host_var in ['ansible_ssh_host', 'ansible_host']: result = socket.getaddrinfo(host_name, None)
ansible_host = host.variables_dict.get(host_var, '') possible_ips = set(x[4][0] for x in result)
if ansible_host: possible_ips.discard(host_name)
hostnames.add(ansible_host) if possible_ips and possible_ips & remote_hosts:
for hostname in hostnames: matches.update(host_mappings[host_name])
try: except socket.gaierror:
result = socket.getaddrinfo(hostname, None) pass
possible_ips = set(x[4][0] for x in result)
possible_ips.discard(hostname)
if possible_ips and possible_ips & remote_hosts:
matches.add(host)
except socket.gaierror:
pass
# Return all matches found.
return matches return matches
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):

View File

@@ -527,6 +527,18 @@ class Host(CommonModelNameNotUnique):
self.ansible_facts[module] = facts self.ansible_facts[module] = facts
self.save() self.save()
def get_effective_host_name(self):
'''
Return the name of the host that will be used in actual ansible
command run.
'''
host_name = self.name
if 'ansible_ssh_host' in self.variables_dict:
host_name = self.variables_dict['ansible_ssh_host']
if 'ansible_host' in self.variables_dict:
host_name = self.variables_dict['ansible_host']
return host_name
class Group(CommonModelNameNotUnique): class Group(CommonModelNameNotUnique):
''' '''

View File

@@ -3,7 +3,7 @@ import yaml
from awx.api.serializers import JobLaunchSerializer from awx.api.serializers import JobLaunchSerializer
from awx.main.models.credential import Credential from awx.main.models.credential import Credential
from awx.main.models.inventory import Inventory from awx.main.models.inventory import Inventory, Host
from awx.main.models.jobs import Job, JobTemplate from awx.main.models.jobs import Job, JobTemplate
from awx.api.versioning import reverse from awx.api.versioning import reverse
@@ -431,3 +431,31 @@ def test_callback_ignore_unprompted_extra_var(mocker, survey_spec_factory, job_t
'limit': 'single-host'},) 'limit': 'single-host'},)
mock_job.signal_start.assert_called_once() mock_job.signal_start.assert_called_once()
@pytest.mark.django_db
@pytest.mark.job_runtime_vars
def test_callback_find_matching_hosts(mocker, get, job_template_prompts, admin_user):
job_template = job_template_prompts(False)
job_template.host_config_key = "foo"
job_template.save()
host_with_alias = Host(name='localhost', inventory=job_template.inventory)
host_with_alias.save()
with mocker.patch('awx.main.access.BaseAccess.check_license'):
r = get(reverse('api:job_template_callback', kwargs={'pk': job_template.pk}),
user=admin_user, expect=200)
assert tuple(r.data['matching_hosts']) == ('localhost',)
@pytest.mark.django_db
@pytest.mark.job_runtime_vars
def test_callback_extra_var_takes_priority_over_host_name(mocker, get, job_template_prompts, admin_user):
job_template = job_template_prompts(False)
job_template.host_config_key = "foo"
job_template.save()
host_with_alias = Host(name='localhost', variables={'ansible_host': 'foobar'}, inventory=job_template.inventory)
host_with_alias.save()
with mocker.patch('awx.main.access.BaseAccess.check_license'):
r = get(reverse('api:job_template_callback', kwargs={'pk': job_template.pk}),
user=admin_user, expect=200)
assert not r.data['matching_hosts']