From dab678c5cc3f451da558068c6ff8d0bf7fb6b6dd Mon Sep 17 00:00:00 2001 From: AlanCoding Date: Mon, 27 Aug 2018 11:08:06 -0400 Subject: [PATCH] Implement splitting logic in inventory & job task code --- awx/api/serializers.py | 1 + .../templates/api/inventory_script_view.md | 3 + awx/api/urls/job_template.py | 2 + awx/api/views/__init__.py | 13 ++- awx/main/models/inventory.py | 81 ++++++++++++------- awx/main/scheduler/task_manager.py | 2 +- awx/main/tasks.py | 11 ++- .../tests/functional/models/test_inventory.py | 8 ++ 8 files changed, 89 insertions(+), 32 deletions(-) diff --git a/awx/api/serializers.py b/awx/api/serializers.py index 8073b89735..e880648467 100644 --- a/awx/api/serializers.py +++ b/awx/api/serializers.py @@ -3028,6 +3028,7 @@ class JobTemplateSerializer(JobTemplateMixin, UnifiedJobTemplateSerializer, JobO labels = self.reverse('api:job_template_label_list', kwargs={'pk': obj.pk}), object_roles = self.reverse('api:job_template_object_roles_list', kwargs={'pk': obj.pk}), instance_groups = self.reverse('api:job_template_instance_groups_list', kwargs={'pk': obj.pk}), + sharded_jobs = self.reverse('api:job_template_sharded_jobs_list', kwargs={'pk': obj.pk}), )) if self.version > 1: res['copy'] = self.reverse('api:job_template_copy', kwargs={'pk': obj.pk}) diff --git a/awx/api/templates/api/inventory_script_view.md b/awx/api/templates/api/inventory_script_view.md index 07656c1eff..19cfff28ce 100644 --- a/awx/api/templates/api/inventory_script_view.md +++ b/awx/api/templates/api/inventory_script_view.md @@ -26,6 +26,9 @@ string of `?all=1` to return all hosts, including disabled ones. Specify a query string of `?towervars=1` to add variables to the hostvars of each host that specifies its enabled state and database ID. +Specify a query string of `?subset=shard2of5` to produce an inventory that +has a restricted number of hosts according to the rules of job splitting. 
+ To apply multiple query strings, join them with the `&` character, like `?hostvars=1&all=1`. ## Host Response diff --git a/awx/api/urls/job_template.py b/awx/api/urls/job_template.py index b11dbf4fea..9b830d64a7 100644 --- a/awx/api/urls/job_template.py +++ b/awx/api/urls/job_template.py @@ -8,6 +8,7 @@ from awx.api.views import ( JobTemplateDetail, JobTemplateLaunch, JobTemplateJobsList, + JobTemplateShardedJobsList, JobTemplateCallback, JobTemplateSchedulesList, JobTemplateSurveySpec, @@ -28,6 +29,7 @@ urls = [ url(r'^(?P[0-9]+)/$', JobTemplateDetail.as_view(), name='job_template_detail'), url(r'^(?P[0-9]+)/launch/$', JobTemplateLaunch.as_view(), name='job_template_launch'), url(r'^(?P[0-9]+)/jobs/$', JobTemplateJobsList.as_view(), name='job_template_jobs_list'), + url(r'^(?P[0-9]+)/sharded_jobs/$', JobTemplateShardedJobsList.as_view(), name='job_template_sharded_jobs_list'), url(r'^(?P[0-9]+)/callback/$', JobTemplateCallback.as_view(), name='job_template_callback'), url(r'^(?P[0-9]+)/schedules/$', JobTemplateSchedulesList.as_view(), name='job_template_schedules_list'), url(r'^(?P[0-9]+)/survey_spec/$', JobTemplateSurveySpec.as_view(), name='job_template_survey_spec'), diff --git a/awx/api/views/__init__.py b/awx/api/views/__init__.py index de8756ce40..8bd9f25dc2 100644 --- a/awx/api/views/__init__.py +++ b/awx/api/views/__init__.py @@ -2452,6 +2452,7 @@ class InventoryScriptView(RetrieveAPIView): hostvars = bool(request.query_params.get('hostvars', '')) towervars = bool(request.query_params.get('towervars', '')) show_all = bool(request.query_params.get('all', '')) + subset = request.query_params.get('subset', '') if hostname: hosts_q = dict(name=hostname) if not show_all: @@ -2461,7 +2462,8 @@ class InventoryScriptView(RetrieveAPIView): return Response(obj.get_script_data( hostvars=hostvars, towervars=towervars, - show_all=show_all + show_all=show_all, + subset=subset )) @@ -3396,6 +3398,15 @@ class JobTemplateJobsList(SubListCreateAPIView): return methods 
+class JobTemplateShardedJobsList(SubListCreateAPIView): + + model = WorkflowJob + serializer_class = WorkflowJobListSerializer + parent_model = JobTemplate + relationship = 'sharded_jobs' + parent_key = 'job_template' + + class JobTemplateInstanceGroupsList(SubListAttachDetachAPIView): model = InstanceGroup diff --git a/awx/main/models/inventory.py b/awx/main/models/inventory.py index 175fa40236..26a5be6c3a 100644 --- a/awx/main/models/inventory.py +++ b/awx/main/models/inventory.py @@ -19,6 +19,9 @@ from django.core.exceptions import ValidationError from django.utils.timezone import now from django.db.models import Q +# REST Framework +from rest_framework.exceptions import ParseError + # AWX from awx.api.versioning import reverse from awx.main.constants import CLOUD_PROVIDERS @@ -217,67 +220,87 @@ class Inventory(CommonModelNameNotUnique, ResourceMixin, RelatedJobsMixin): group_children.add(from_group_id) return group_children_map - def get_script_data(self, hostvars=False, towervars=False, show_all=False): - if show_all: - hosts_q = dict() - else: - hosts_q = dict(enabled=True) + @staticmethod + def parse_shard_params(shard_str): + m = re.match(r"shard(?P\d+)of(?P\d+)", shard_str) + if not m: + raise ParseError(_('Could not parse subset as shard specification.')) + offset = int(m.group('offset')) + step = int(m.group('step')) + if offset > step: + raise ParseError(_('Shard offset must be greater than total number of shards.')) + return (offset, step) + + def get_script_data(self, hostvars=False, towervars=False, show_all=False, subset=None): + hosts_kw = dict() + if not show_all: + hosts_kw['enabled'] = True + fetch_fields = ['name', 'id', 'variables'] + if towervars: + fetch_fields.append('enabled') + hosts = self.hosts.filter(**hosts_kw).order_by('name').only(*fetch_fields) + if subset: + if not isinstance(subset, six.string_types): + raise ParseError(_('Inventory subset argument must be a string.')) + if subset.startswith('shard'): + offset, step = 
Inventory.parse_shard_params(subset) + hosts = hosts[offset::step] + else: + raise ParseError(_('Subset does not use any supported syntax.')) + data = dict() + all_group = data.setdefault('all', dict()) if self.variables_dict: - all_group = data.setdefault('all', dict()) all_group['vars'] = self.variables_dict + if self.kind == 'smart': - if len(self.hosts.all()) == 0: - return {} - else: - all_group = data.setdefault('all', dict()) - smart_hosts_qs = self.hosts.filter(**hosts_q).all() - smart_hosts = list(smart_hosts_qs.values_list('name', flat=True)) - all_group['hosts'] = smart_hosts + all_group['hosts'] = [host.name for host in hosts] else: - # Add hosts without a group to the all group. - groupless_hosts_qs = self.hosts.filter(groups__isnull=True, **hosts_q) - groupless_hosts = list(groupless_hosts_qs.values_list('name', flat=True)) - if groupless_hosts: - all_group = data.setdefault('all', dict()) - all_group['hosts'] = groupless_hosts + # Keep track of hosts that are members of a group + grouped_hosts = set([]) # Build in-memory mapping of groups and their hosts. - group_hosts_kw = dict(group__inventory_id=self.id, host__inventory_id=self.id) - if 'enabled' in hosts_q: - group_hosts_kw['host__enabled'] = hosts_q['enabled'] - group_hosts_qs = Group.hosts.through.objects.filter(**group_hosts_kw) - group_hosts_qs = group_hosts_qs.values_list('group_id', 'host_id', 'host__name') + group_hosts_qs = Group.hosts.through.objects.filter( + group__inventory_id=self.id, + host__inventory_id=self.id + ).values_list('group_id', 'host_id', 'host__name') group_hosts_map = {} for group_id, host_id, host_name in group_hosts_qs: group_hostnames = group_hosts_map.setdefault(group_id, []) group_hostnames.append(host_name) + grouped_hosts.add(host_name) # Build in-memory mapping of groups and their children. 
group_parents_qs = Group.parents.through.objects.filter( from_group__inventory_id=self.id, to_group__inventory_id=self.id, - ) - group_parents_qs = group_parents_qs.values_list('from_group_id', 'from_group__name', - 'to_group_id') + ).values_list('from_group_id', 'from_group__name', 'to_group_id') group_children_map = {} for from_group_id, from_group_name, to_group_id in group_parents_qs: group_children = group_children_map.setdefault(to_group_id, []) group_children.append(from_group_name) # Now use in-memory maps to build up group info. - for group in self.groups.all(): + for group in self.groups.only('name', 'id', 'variables'): group_info = dict() group_info['hosts'] = group_hosts_map.get(group.id, []) group_info['children'] = group_children_map.get(group.id, []) group_info['vars'] = group.variables_dict data[group.name] = group_info + # Add ungrouped hosts to all group + all_group['hosts'] = [host.name for host in hosts if host.name not in grouped_hosts] + + # Remove any empty groups + for group_name in list(data.keys()): + if not data.get(group_name, {}).get('hosts', []): + data.pop(group_name) + if hostvars: data.setdefault('_meta', dict()) data['_meta'].setdefault('hostvars', dict()) - for host in self.hosts.filter(**hosts_q): + for host in hosts: data['_meta']['hostvars'][host.name] = host.variables_dict if towervars: tower_dict = dict(remote_tower_enabled=str(host.enabled).lower(), diff --git a/awx/main/scheduler/task_manager.py b/awx/main/scheduler/task_manager.py index a5bfccb967..4f21e56903 100644 --- a/awx/main/scheduler/task_manager.py +++ b/awx/main/scheduler/task_manager.py @@ -118,7 +118,7 @@ class TaskManager(): kv = spawn_node.get_job_kwargs() job = spawn_node.unified_job_template.create_unified_job(**kv) if 'job_shard' in spawn_node.ancestor_artifacts: - job.name = "{} - {}".format(job.name, spawn_node.ancestor_artifacts['job_shard'] + 1) + job.name = six.text_type("{} - {}").format(job.name, spawn_node.ancestor_artifacts['job_shard'] + 1) 
job.save() spawn_node.job = job spawn_node.save() diff --git a/awx/main/tasks.py b/awx/main/tasks.py index 0fdbef8036..db573b0b68 100644 --- a/awx/main/tasks.py +++ b/awx/main/tasks.py @@ -825,7 +825,16 @@ class BaseTask(object): return False def build_inventory(self, instance, **kwargs): - json_data = json.dumps(instance.inventory.get_script_data(hostvars=True)) + workflow_job = instance.get_workflow_job() + if workflow_job and workflow_job.job_template_id: + shard_address = 'shard{0}of{1}'.format( + instance.unified_job_node.ancestor_artifacts['job_shard'], + workflow_job.workflow_job_nodes.count() + ) + script_data = instance.inventory.get_script_data(hostvars=True, subset=shard_address) + else: + script_data = instance.inventory.get_script_data(hostvars=True) + json_data = json.dumps(script_data) handle, path = tempfile.mkstemp(dir=kwargs.get('private_data_dir', None)) f = os.fdopen(handle, 'w') f.write('#! /usr/bin/env python\n# -*- coding: utf-8 -*-\nprint %r\n' % json_data) diff --git a/awx/main/tests/functional/models/test_inventory.py b/awx/main/tests/functional/models/test_inventory.py index 57365b914b..34eb1d7b13 100644 --- a/awx/main/tests/functional/models/test_inventory.py +++ b/awx/main/tests/functional/models/test_inventory.py @@ -38,6 +38,14 @@ class TestInventoryScript: 'remote_tower_id': host.id } + def test_shard_subset(self, inventory): + for i in range(3): + inventory.hosts.create(name='host{}'.format(i)) + for i in range(3): + assert inventory.get_script_data(subset='shard{}of3'.format(i)) == { + 'all': {'hosts': ['host{}'.format(i)]} + } + @pytest.mark.django_db class TestActiveCount: