Implement splitting logic in inventory & job task code

This commit is contained in:
AlanCoding
2018-08-27 11:08:06 -04:00
parent 44ffcf86de
commit dab678c5cc
8 changed files with 89 additions and 32 deletions

View File

@@ -3028,6 +3028,7 @@ class JobTemplateSerializer(JobTemplateMixin, UnifiedJobTemplateSerializer, JobO
labels = self.reverse('api:job_template_label_list', kwargs={'pk': obj.pk}), labels = self.reverse('api:job_template_label_list', kwargs={'pk': obj.pk}),
object_roles = self.reverse('api:job_template_object_roles_list', kwargs={'pk': obj.pk}), object_roles = self.reverse('api:job_template_object_roles_list', kwargs={'pk': obj.pk}),
instance_groups = self.reverse('api:job_template_instance_groups_list', kwargs={'pk': obj.pk}), instance_groups = self.reverse('api:job_template_instance_groups_list', kwargs={'pk': obj.pk}),
sharded_jobs = self.reverse('api:job_template_sharded_jobs_list', kwargs={'pk': obj.pk}),
)) ))
if self.version > 1: if self.version > 1:
res['copy'] = self.reverse('api:job_template_copy', kwargs={'pk': obj.pk}) res['copy'] = self.reverse('api:job_template_copy', kwargs={'pk': obj.pk})

View File

@@ -26,6 +26,9 @@ string of `?all=1` to return all hosts, including disabled ones.
Specify a query string of `?towervars=1` to add variables Specify a query string of `?towervars=1` to add variables
to the hostvars of each host that specifies its enabled state and database ID. to the hostvars of each host that specifies its enabled state and database ID.
Specify a query string of `?subset=shard2of5` to produce an inventory that
has a restricted number of hosts according to the rules of job splitting.
To apply multiple query strings, join them with the `&` character, like `?hostvars=1&all=1`. To apply multiple query strings, join them with the `&` character, like `?hostvars=1&all=1`.
## Host Response ## Host Response

View File

@@ -8,6 +8,7 @@ from awx.api.views import (
JobTemplateDetail, JobTemplateDetail,
JobTemplateLaunch, JobTemplateLaunch,
JobTemplateJobsList, JobTemplateJobsList,
JobTemplateShardedJobsList,
JobTemplateCallback, JobTemplateCallback,
JobTemplateSchedulesList, JobTemplateSchedulesList,
JobTemplateSurveySpec, JobTemplateSurveySpec,
@@ -28,6 +29,7 @@ urls = [
url(r'^(?P<pk>[0-9]+)/$', JobTemplateDetail.as_view(), name='job_template_detail'), url(r'^(?P<pk>[0-9]+)/$', JobTemplateDetail.as_view(), name='job_template_detail'),
url(r'^(?P<pk>[0-9]+)/launch/$', JobTemplateLaunch.as_view(), name='job_template_launch'), url(r'^(?P<pk>[0-9]+)/launch/$', JobTemplateLaunch.as_view(), name='job_template_launch'),
url(r'^(?P<pk>[0-9]+)/jobs/$', JobTemplateJobsList.as_view(), name='job_template_jobs_list'), url(r'^(?P<pk>[0-9]+)/jobs/$', JobTemplateJobsList.as_view(), name='job_template_jobs_list'),
url(r'^(?P<pk>[0-9]+)/sharded_jobs/$', JobTemplateShardedJobsList.as_view(), name='job_template_sharded_jobs_list'),
url(r'^(?P<pk>[0-9]+)/callback/$', JobTemplateCallback.as_view(), name='job_template_callback'), url(r'^(?P<pk>[0-9]+)/callback/$', JobTemplateCallback.as_view(), name='job_template_callback'),
url(r'^(?P<pk>[0-9]+)/schedules/$', JobTemplateSchedulesList.as_view(), name='job_template_schedules_list'), url(r'^(?P<pk>[0-9]+)/schedules/$', JobTemplateSchedulesList.as_view(), name='job_template_schedules_list'),
url(r'^(?P<pk>[0-9]+)/survey_spec/$', JobTemplateSurveySpec.as_view(), name='job_template_survey_spec'), url(r'^(?P<pk>[0-9]+)/survey_spec/$', JobTemplateSurveySpec.as_view(), name='job_template_survey_spec'),

View File

@@ -2452,6 +2452,7 @@ class InventoryScriptView(RetrieveAPIView):
hostvars = bool(request.query_params.get('hostvars', '')) hostvars = bool(request.query_params.get('hostvars', ''))
towervars = bool(request.query_params.get('towervars', '')) towervars = bool(request.query_params.get('towervars', ''))
show_all = bool(request.query_params.get('all', '')) show_all = bool(request.query_params.get('all', ''))
subset = request.query_params.get('subset', '')
if hostname: if hostname:
hosts_q = dict(name=hostname) hosts_q = dict(name=hostname)
if not show_all: if not show_all:
@@ -2461,7 +2462,8 @@ class InventoryScriptView(RetrieveAPIView):
return Response(obj.get_script_data( return Response(obj.get_script_data(
hostvars=hostvars, hostvars=hostvars,
towervars=towervars, towervars=towervars,
show_all=show_all show_all=show_all,
subset=subset
)) ))
@@ -3396,6 +3398,15 @@ class JobTemplateJobsList(SubListCreateAPIView):
return methods return methods
class JobTemplateShardedJobsList(SubListCreateAPIView):
model = WorkflowJob
serializer_class = WorkflowJobListSerializer
parent_model = JobTemplate
relationship = 'sharded_jobs'
parent_key = 'job_template'
class JobTemplateInstanceGroupsList(SubListAttachDetachAPIView): class JobTemplateInstanceGroupsList(SubListAttachDetachAPIView):
model = InstanceGroup model = InstanceGroup

View File

@@ -19,6 +19,9 @@ from django.core.exceptions import ValidationError
from django.utils.timezone import now from django.utils.timezone import now
from django.db.models import Q from django.db.models import Q
# REST Framework
from rest_framework.exceptions import ParseError
# AWX # AWX
from awx.api.versioning import reverse from awx.api.versioning import reverse
from awx.main.constants import CLOUD_PROVIDERS from awx.main.constants import CLOUD_PROVIDERS
@@ -217,67 +220,87 @@ class Inventory(CommonModelNameNotUnique, ResourceMixin, RelatedJobsMixin):
group_children.add(from_group_id) group_children.add(from_group_id)
return group_children_map return group_children_map
def get_script_data(self, hostvars=False, towervars=False, show_all=False): @staticmethod
if show_all: def parse_shard_params(shard_str):
hosts_q = dict() m = re.match(r"shard(?P<offset>\d+)of(?P<step>\d+)", shard_str)
if not m:
raise ParseError(_('Could not parse subset as shard specification.'))
offset = int(m.group('offset'))
step = int(m.group('step'))
if offset > step:
raise ParseError(_('Shard offset must be greater than total number of shards.'))
return (offset, step)
def get_script_data(self, hostvars=False, towervars=False, show_all=False, subset=None):
hosts_kw = dict()
if not show_all:
hosts_kw['enabled'] = True
fetch_fields = ['name', 'id', 'variables']
if towervars:
fetch_fields.append('enabled')
hosts = self.hosts.filter(**hosts_kw).order_by('name').only(*fetch_fields)
if subset:
if not isinstance(subset, six.string_types):
raise ParseError(_('Inventory subset argument must be a string.'))
if subset.startswith('shard'):
offset, step = Inventory.parse_shard_params(subset)
hosts = hosts[offset::step]
else: else:
hosts_q = dict(enabled=True) raise ParseError(_('Subset does not use any supported syntax.'))
data = dict() data = dict()
all_group = data.setdefault('all', dict())
if self.variables_dict: if self.variables_dict:
all_group = data.setdefault('all', dict())
all_group['vars'] = self.variables_dict all_group['vars'] = self.variables_dict
if self.kind == 'smart': if self.kind == 'smart':
if len(self.hosts.all()) == 0: all_group['hosts'] = [host.name for host in hosts]
return {}
else: else:
all_group = data.setdefault('all', dict()) # Keep track of hosts that are members of a group
smart_hosts_qs = self.hosts.filter(**hosts_q).all() grouped_hosts = set([])
smart_hosts = list(smart_hosts_qs.values_list('name', flat=True))
all_group['hosts'] = smart_hosts
else:
# Add hosts without a group to the all group.
groupless_hosts_qs = self.hosts.filter(groups__isnull=True, **hosts_q)
groupless_hosts = list(groupless_hosts_qs.values_list('name', flat=True))
if groupless_hosts:
all_group = data.setdefault('all', dict())
all_group['hosts'] = groupless_hosts
# Build in-memory mapping of groups and their hosts. # Build in-memory mapping of groups and their hosts.
group_hosts_kw = dict(group__inventory_id=self.id, host__inventory_id=self.id) group_hosts_qs = Group.hosts.through.objects.filter(
if 'enabled' in hosts_q: group__inventory_id=self.id,
group_hosts_kw['host__enabled'] = hosts_q['enabled'] host__inventory_id=self.id
group_hosts_qs = Group.hosts.through.objects.filter(**group_hosts_kw) ).values_list('group_id', 'host_id', 'host__name')
group_hosts_qs = group_hosts_qs.values_list('group_id', 'host_id', 'host__name')
group_hosts_map = {} group_hosts_map = {}
for group_id, host_id, host_name in group_hosts_qs: for group_id, host_id, host_name in group_hosts_qs:
group_hostnames = group_hosts_map.setdefault(group_id, []) group_hostnames = group_hosts_map.setdefault(group_id, [])
group_hostnames.append(host_name) group_hostnames.append(host_name)
grouped_hosts.add(host_name)
# Build in-memory mapping of groups and their children. # Build in-memory mapping of groups and their children.
group_parents_qs = Group.parents.through.objects.filter( group_parents_qs = Group.parents.through.objects.filter(
from_group__inventory_id=self.id, from_group__inventory_id=self.id,
to_group__inventory_id=self.id, to_group__inventory_id=self.id,
) ).values_list('from_group_id', 'from_group__name', 'to_group_id')
group_parents_qs = group_parents_qs.values_list('from_group_id', 'from_group__name',
'to_group_id')
group_children_map = {} group_children_map = {}
for from_group_id, from_group_name, to_group_id in group_parents_qs: for from_group_id, from_group_name, to_group_id in group_parents_qs:
group_children = group_children_map.setdefault(to_group_id, []) group_children = group_children_map.setdefault(to_group_id, [])
group_children.append(from_group_name) group_children.append(from_group_name)
# Now use in-memory maps to build up group info. # Now use in-memory maps to build up group info.
for group in self.groups.all(): for group in self.groups.only('name', 'id', 'variables'):
group_info = dict() group_info = dict()
group_info['hosts'] = group_hosts_map.get(group.id, []) group_info['hosts'] = group_hosts_map.get(group.id, [])
group_info['children'] = group_children_map.get(group.id, []) group_info['children'] = group_children_map.get(group.id, [])
group_info['vars'] = group.variables_dict group_info['vars'] = group.variables_dict
data[group.name] = group_info data[group.name] = group_info
# Add ungrouped hosts to all group
all_group['hosts'] = [host.name for host in hosts if host.name not in grouped_hosts]
# Remove any empty groups
for group_name in list(data.keys()):
if not data.get(group_name, {}).get('hosts', []):
data.pop(group_name)
if hostvars: if hostvars:
data.setdefault('_meta', dict()) data.setdefault('_meta', dict())
data['_meta'].setdefault('hostvars', dict()) data['_meta'].setdefault('hostvars', dict())
for host in self.hosts.filter(**hosts_q): for host in hosts:
data['_meta']['hostvars'][host.name] = host.variables_dict data['_meta']['hostvars'][host.name] = host.variables_dict
if towervars: if towervars:
tower_dict = dict(remote_tower_enabled=str(host.enabled).lower(), tower_dict = dict(remote_tower_enabled=str(host.enabled).lower(),

View File

@@ -118,7 +118,7 @@ class TaskManager():
kv = spawn_node.get_job_kwargs() kv = spawn_node.get_job_kwargs()
job = spawn_node.unified_job_template.create_unified_job(**kv) job = spawn_node.unified_job_template.create_unified_job(**kv)
if 'job_shard' in spawn_node.ancestor_artifacts: if 'job_shard' in spawn_node.ancestor_artifacts:
job.name = "{} - {}".format(job.name, spawn_node.ancestor_artifacts['job_shard'] + 1) job.name = six.text_type("{} - {}").format(job.name, spawn_node.ancestor_artifacts['job_shard'] + 1)
job.save() job.save()
spawn_node.job = job spawn_node.job = job
spawn_node.save() spawn_node.save()

View File

@@ -825,7 +825,16 @@ class BaseTask(object):
return False return False
def build_inventory(self, instance, **kwargs): def build_inventory(self, instance, **kwargs):
json_data = json.dumps(instance.inventory.get_script_data(hostvars=True)) workflow_job = instance.get_workflow_job()
if workflow_job and workflow_job.job_template_id:
shard_address = 'shard{0}of{1}'.format(
instance.unified_job_node.ancestor_artifacts['job_shard'],
workflow_job.workflow_job_nodes.count()
)
script_data = instance.inventory.get_script_data(hostvars=True, subset=shard_address)
else:
script_data = instance.inventory.get_script_data(hostvars=True)
json_data = json.dumps(script_data)
handle, path = tempfile.mkstemp(dir=kwargs.get('private_data_dir', None)) handle, path = tempfile.mkstemp(dir=kwargs.get('private_data_dir', None))
f = os.fdopen(handle, 'w') f = os.fdopen(handle, 'w')
f.write('#! /usr/bin/env python\n# -*- coding: utf-8 -*-\nprint %r\n' % json_data) f.write('#! /usr/bin/env python\n# -*- coding: utf-8 -*-\nprint %r\n' % json_data)

View File

@@ -38,6 +38,14 @@ class TestInventoryScript:
'remote_tower_id': host.id 'remote_tower_id': host.id
} }
def test_shard_subset(self, inventory):
for i in range(3):
inventory.hosts.create(name='host{}'.format(i))
for i in range(3):
assert inventory.get_script_data(subset='shard{}of3'.format(i)) == {
'all': {'hosts': ['host{}'.format(i)]}
}
@pytest.mark.django_db @pytest.mark.django_db
class TestActiveCount: class TestActiveCount: