From daf43101766a5d5efe159d36279c3e8a547592fb Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Thu, 26 Aug 2021 07:24:14 -0400 Subject: [PATCH] Clean up work_type processing and fix execution vs control capacity (#10930) * Clean up added work_type processing for mesh_code branch * track both execution and control capacity * Remove unused execution_capacity property * Count all forms of capacity to make test pass * Force jobs to be on execution nodes, updates on control nodes * Introduce capacity_type property to abstract some details out * Update test to cover all job types at same time * Register OpenShift nodes as control types * Remove unqualified consumed_capacity from task manager and make unit tests work * Remove unqualified consumed_capacity from task manager and make unit tests work * Update unit test to execution vs control TM logic changes * Fix bug, else handling for work_type method --- awx/main/managers.py | 12 +++- awx/main/models/ad_hoc_commands.py | 4 -- awx/main/models/ha.py | 10 ++-- awx/main/models/inventory.py | 4 -- awx/main/models/jobs.py | 8 --- awx/main/models/projects.py | 4 -- awx/main/models/unified_jobs.py | 19 ++----- awx/main/scheduler/task_manager.py | 55 +++++++++++++------ awx/main/tasks.py | 31 ++++------- awx/main/tests/functional/conftest.py | 6 +- .../task_management/test_scheduler.py | 43 +++++++++++++++ awx/main/tests/functional/test_instances.py | 4 +- awx/main/tests/unit/models/test_ha.py | 5 +- awx/main/tests/unit/test_capacity.py | 10 +++- awx/main/tests/unit/utils/test_common.py | 18 +++++- awx/main/utils/common.py | 14 +++++ 16 files changed, 159 insertions(+), 88 deletions(-) diff --git a/awx/main/managers.py b/awx/main/managers.py index 05ffb3ecbb..eaee03ed7e 100644 --- a/awx/main/managers.py +++ b/awx/main/managers.py @@ -10,6 +10,7 @@ from django.conf import settings from awx.main.utils.filters import SmartFilter from awx.main.utils.pglock import advisory_lock +from awx.main.utils.common import get_capacity_type from awx.main.constants import RECEPTOR_PENDING ___all__ = ['HostManager', 'InstanceManager', 'InstanceGroupManager', 'DeferJobCreatedManager'] @@ -160,7 +161,10 @@ class InstanceManager(models.Manager): from awx.main.management.commands.register_queue import RegisterQueue pod_ip = os.environ.get('MY_POD_IP') - registered = self.register(ip_address=pod_ip) + if settings.IS_K8S: + registered = self.register(ip_address=pod_ip, node_type='control') + else: + registered = self.register(ip_address=pod_ip) RegisterQueue(settings.DEFAULT_CONTROL_PLANE_QUEUE_NAME, 100, 0, [], is_container_group=False).register() RegisterQueue(settings.DEFAULT_EXECUTION_QUEUE_NAME, 100, 0, [], is_container_group=True).register() return registered @@ -204,6 +208,8 @@ class InstanceGroupManager(models.Manager): if name not in graph: graph[name] = {} graph[name]['consumed_capacity'] = 0 + for capacity_type in ('execution', 'control'): + graph[name][f'consumed_{capacity_type}_capacity'] = 0 if breakdown: graph[name]['committed_capacity'] = 0 graph[name]['running_capacity'] = 0 @@ -239,6 +245,8 @@ class InstanceGroupManager(models.Manager): if group_name not in graph: self.zero_out_group(graph, group_name, breakdown) graph[group_name]['consumed_capacity'] += impact + capacity_type = get_capacity_type(t) + graph[group_name][f'consumed_{capacity_type}_capacity'] += impact if breakdown: graph[group_name]['committed_capacity'] += impact elif t.status == 'running': @@ -256,6 +264,8 @@ class InstanceGroupManager(models.Manager): if group_name not in graph: self.zero_out_group(graph, group_name, breakdown) graph[group_name]['consumed_capacity'] += impact + capacity_type = get_capacity_type(t) + graph[group_name][f'consumed_{capacity_type}_capacity'] += impact if breakdown: graph[group_name]['running_capacity'] += impact else: diff --git a/awx/main/models/ad_hoc_commands.py b/awx/main/models/ad_hoc_commands.py index f15af65f61..9873888981 100644 --- a/awx/main/models/ad_hoc_commands.py +++ b/awx/main/models/ad_hoc_commands.py @@ -152,10 +152,6 @@ class AdHocCommand(UnifiedJob, JobNotificationMixin): def is_container_group_task(self): return bool(self.instance_group and self.instance_group.is_container_group) - @property - def can_run_containerized(self): - return True - def get_absolute_url(self, request=None): return reverse('api:ad_hoc_command_detail', kwargs={'pk': self.pk}, request=request) diff --git a/awx/main/models/ha.py b/awx/main/models/ha.py index 00475254bd..aaf6e26835 100644 --- a/awx/main/models/ha.py +++ b/awx/main/models/ha.py @@ -269,10 +269,6 @@ class InstanceGroup(HasPolicyEditsMixin, BaseModel, RelatedJobsMixin): def capacity(self): return sum([inst.capacity for inst in self.instances.all()]) - @property - def execution_capacity(self): - return sum([inst.capacity for inst in self.instances.filter(node_type__in=['hybrid', 'execution'])]) - @property def jobs_running(self): return UnifiedJob.objects.filter(status__in=('running', 'waiting'), instance_group=self).count() @@ -295,7 +291,7 @@ class InstanceGroup(HasPolicyEditsMixin, BaseModel, RelatedJobsMixin): def fit_task_to_most_remaining_capacity_instance(task, instances): instance_most_capacity = None for i in instances: - if i.node_type == 'control': + if i.node_type not in (task.capacity_type, 'hybrid'): continue if i.remaining_capacity >= task.task_impact and ( instance_most_capacity is None or i.remaining_capacity > instance_most_capacity.remaining_capacity @@ -304,9 +300,11 @@ class InstanceGroup(HasPolicyEditsMixin, BaseModel, RelatedJobsMixin): return instance_most_capacity @staticmethod - def find_largest_idle_instance(instances): + def find_largest_idle_instance(instances, capacity_type='execution'): largest_instance = None for i in instances: + if i.node_type not in (capacity_type, 'hybrid'): + continue if i.jobs_running == 0: if largest_instance is None: largest_instance = i diff --git a/awx/main/models/inventory.py b/awx/main/models/inventory.py index 63bb738779..2d5508d4d2 100644 --- a/awx/main/models/inventory.py +++ b/awx/main/models/inventory.py @@ -1214,10 +1214,6 @@ class InventoryUpdate(UnifiedJob, InventorySourceOptions, JobNotificationMixin, def is_container_group_task(self): return bool(self.instance_group and self.instance_group.is_container_group) - @property - def can_run_containerized(self): - return True - def _get_parent_field_name(self): return 'inventory_source' diff --git a/awx/main/models/jobs.py b/awx/main/models/jobs.py index 34c3610220..2d2f3ade16 100644 --- a/awx/main/models/jobs.py +++ b/awx/main/models/jobs.py @@ -743,10 +743,6 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana return "$hidden due to Ansible no_log flag$" return artifacts - @property - def can_run_containerized(self): - return True - @property def is_container_group_task(self): return bool(self.instance_group and self.instance_group.is_container_group) @@ -1236,10 +1232,6 @@ class SystemJob(UnifiedJob, SystemJobOptions, JobNotificationMixin): return UnpartitionedSystemJobEvent return SystemJobEvent - @property - def can_run_on_control_plane(self): - return True - @property def task_impact(self): return 5 diff --git a/awx/main/models/projects.py b/awx/main/models/projects.py index 1c34871205..d64e2483bd 100644 --- a/awx/main/models/projects.py +++ b/awx/main/models/projects.py @@ -553,10 +553,6 @@ class ProjectUpdate(UnifiedJob, ProjectOptions, JobNotificationMixin, TaskManage websocket_data.update(dict(project_id=self.project.id)) return websocket_data - @property - def can_run_on_control_plane(self): - return True - @property def event_class(self): if self.has_unpartitioned_events: diff --git a/awx/main/models/unified_jobs.py b/awx/main/models/unified_jobs.py index 01c5467768..f9f5d0a133 100644 --- a/awx/main/models/unified_jobs.py +++ b/awx/main/models/unified_jobs.py @@ -36,21 +36,21 @@ from awx.main.dispatch import get_local_queuename from awx.main.dispatch.control import Control as ControlDispatcher from awx.main.registrar import activity_stream_registrar from awx.main.models.mixins import ResourceMixin, TaskManagerUnifiedJobMixin, ExecutionEnvironmentMixin -from awx.main.utils import ( +from awx.main.utils.common import ( camelcase_to_underscore, get_model_for_type, - encrypt_dict, - decrypt_field, _inventory_updates, copy_model_by_class, copy_m2m_relationships, get_type_for_model, parse_yaml_or_json, getattr_dne, - polymorphic, schedule_task_manager, get_event_partition_epoch, + get_capacity_type, ) +from awx.main.utils.encryption import encrypt_dict, decrypt_field +from awx.main.utils import polymorphic from awx.main.constants import ACTIVE_STATES, CAN_CANCEL from awx.main.redact import UriCleaner, REPLACE_STR from awx.main.consumers import emit_channel_notification @@ -740,15 +740,8 @@ class UnifiedJob( raise NotImplementedError # Implement in subclasses. @property - def can_run_on_control_plane(self): - if settings.IS_K8S: - return False - - return True - - @property - def can_run_containerized(self): - return False + def capacity_type(self): + return get_capacity_type(self) def _get_parent_field_name(self): return 'unified_job_template' # Override in subclasses. diff --git a/awx/main/scheduler/task_manager.py b/awx/main/scheduler/task_manager.py index 99fd8015f7..07ac9f7bd3 100644 --- a/awx/main/scheduler/task_manager.py +++ b/awx/main/scheduler/task_manager.py @@ -87,7 +87,21 @@ class TaskManager: instances_by_hostname = {i.hostname: i for i in instances_partial} for rampart_group in InstanceGroup.objects.prefetch_related('instances'): - self.graph[rampart_group.name] = dict(graph=DependencyGraph(), capacity_total=rampart_group.execution_capacity, consumed_capacity=0, instances=[]) + self.graph[rampart_group.name] = dict( + graph=DependencyGraph(), + execution_capacity=0, + control_capacity=0, + consumed_capacity=0, + consumed_control_capacity=0, + consumed_execution_capacity=0, + instances=[], + ) + for instance in rampart_group.instances.all(): + if not instance.enabled: + continue + for capacity_type in ('control', 'execution'): + if instance.node_type in (capacity_type, 'hybrid'): + self.graph[rampart_group.name][f'{capacity_type}_capacity'] += instance.capacity for instance in rampart_group.instances.filter(enabled=True).order_by('hostname'): if instance.hostname in instances_by_hostname: self.graph[rampart_group.name]['instances'].append(instances_by_hostname[instance.hostname]) @@ -281,7 +295,7 @@ class TaskManager: task.instance_group = rampart_group if match is None: logger.warn('No available capacity to run containerized <{}>.'.format(task.log_format)) - elif task.can_run_containerized and any(ig.is_container_group for ig in task.preferred_instance_groups): + elif task.capacity_type == 'execution' and any(ig.is_container_group for ig in task.preferred_instance_groups): task.controller_node = match.hostname else: # project updates and inventory updates don't *actually* run in pods, so @@ -291,12 +305,15 @@ class TaskManager: else: task.instance_group = rampart_group task.execution_node = instance.hostname - try: - controller_node = Instance.choose_online_control_plane_node() - except IndexError: - logger.warning("No control plane nodes available to manage {}".format(task.log_format)) - return - task.controller_node = controller_node + if instance.node_type == 'execution': + try: + task.controller_node = Instance.choose_online_control_plane_node() + except IndexError: + logger.warning("No control plane nodes available to manage {}".format(task.log_format)) + return + else: + # control plane nodes will manage jobs locally for performance and resilience + task.controller_node = task.execution_node logger.debug('Submitting job {} to queue {} controlled by {}.'.format(task.log_format, task.execution_node, task.controller_node)) with disable_activity_stream(): task.celery_task_id = str(uuid.uuid4()) @@ -304,7 +321,7 @@ class TaskManager: task.log_lifecycle("waiting") if rampart_group is not None: - self.consume_capacity(task, rampart_group.name) + self.consume_capacity(task, rampart_group.name, instance=instance) def post_commit(): if task.status != 'failed' and type(task) is not WorkflowJob: @@ -493,24 +510,25 @@ class TaskManager: continue for rampart_group in preferred_instance_groups: - if task.can_run_containerized and rampart_group.is_container_group: + if task.capacity_type == 'execution' and rampart_group.is_container_group: self.graph[rampart_group.name]['graph'].add_job(task) self.start_task(task, rampart_group, task.get_jobs_fail_chain(), None) found_acceptable_queue = True break - if not task.can_run_on_control_plane: + # TODO: remove this after we have confidence that OCP control nodes are reporting node_type=control + if settings.IS_K8S and task.capacity_type == 'execution': logger.debug("Skipping group {}, task cannot run on control plane".format(rampart_group.name)) continue - remaining_capacity = self.get_remaining_capacity(rampart_group.name) - if task.task_impact > 0 and self.get_remaining_capacity(rampart_group.name) <= 0: + remaining_capacity = self.get_remaining_capacity(rampart_group.name, capacity_type=task.capacity_type) + if task.task_impact > 0 and remaining_capacity <= 0: logger.debug("Skipping group {}, remaining_capacity {} <= 0".format(rampart_group.name, remaining_capacity)) continue execution_instance = InstanceGroup.fit_task_to_most_remaining_capacity_instance( task, self.graph[rampart_group.name]['instances'] - ) or InstanceGroup.find_largest_idle_instance(self.graph[rampart_group.name]['instances']) + ) or InstanceGroup.find_largest_idle_instance(self.graph[rampart_group.name]['instances'], capacity_type=task.capacity_type) if execution_instance or rampart_group.is_container_group: if not rampart_group.is_container_group: @@ -581,16 +599,19 @@ class TaskManager: def calculate_capacity_consumed(self, tasks): self.graph = InstanceGroup.objects.capacity_values(tasks=tasks, graph=self.graph) - def consume_capacity(self, task, instance_group): + def consume_capacity(self, task, instance_group, instance=None): logger.debug( '{} consumed {} capacity units from {} with prior total of {}'.format( task.log_format, task.task_impact, instance_group, self.graph[instance_group]['consumed_capacity'] ) ) self.graph[instance_group]['consumed_capacity'] += task.task_impact + for capacity_type in ('control', 'execution'): + if instance is None or instance.node_type in ('hybrid', capacity_type): + self.graph[instance_group][f'consumed_{capacity_type}_capacity'] += task.task_impact - def get_remaining_capacity(self, instance_group): - return self.graph[instance_group]['capacity_total'] - self.graph[instance_group]['consumed_capacity'] + def get_remaining_capacity(self, instance_group, capacity_type='execution'): + return self.graph[instance_group][f'{capacity_type}_capacity'] - self.graph[instance_group][f'consumed_{capacity_type}_capacity'] def process_tasks(self, all_sorted_tasks): running_tasks = [t for t in all_sorted_tasks if t.status in ['waiting', 'running']] diff --git a/awx/main/tasks.py b/awx/main/tasks.py index fa65f97f40..47d55cab72 100644 --- a/awx/main/tasks.py +++ b/awx/main/tasks.py @@ -3001,18 +3001,18 @@ class AWXReceptorJob: execution_environment_params = self.task.build_execution_environment_params(self.task.instance, runner_params['private_data_dir']) self.runner_params['settings'].update(execution_environment_params) - def run(self, work_type=None): + def run(self): # We establish a connection to the Receptor socket receptor_ctl = get_receptor_ctl() try: - return self._run_internal(receptor_ctl, work_type=work_type) + return self._run_internal(receptor_ctl) finally: # Make sure to always release the work unit if we established it if self.unit_id is not None and settings.RECEPTOR_RELEASE_WORK: receptor_ctl.simple_command(f"work release {self.unit_id}") - def _run_internal(self, receptor_ctl, work_type=None): + def _run_internal(self, receptor_ctl): # Create a socketpair. Where the left side will be used for writing our payload # (private data dir, kwargs). The right side will be passed to Receptor for # reading. @@ -3024,13 +3024,9 @@ class AWXReceptorJob: # submit our work, passing # in the right side of our socketpair for reading. _kw = {} - work_type = work_type or self.work_type - if work_type == 'ansible-runner': + if self.work_type == 'ansible-runner': _kw['node'] = self.task.instance.execution_node - logger.debug(f'receptorctl.submit_work(node={_kw["node"]})') - else: - logger.debug(f'receptorctl.submit_work({work_type})') - result = receptor_ctl.submit_work(worktype=work_type, payload=sockout.makefile('rb'), params=self.receptor_params, **_kw) + result = receptor_ctl.submit_work(worktype=self.work_type, payload=sockout.makefile('rb'), params=self.receptor_params, **_kw) self.unit_id = result['unitid'] self.task.update_model(self.task.instance.pk, work_unit_id=result['unitid']) @@ -3136,18 +3132,11 @@ class AWXReceptorJob: def work_type(self): if self.task.instance.is_container_group_task: if self.credential: - work_type = 'kubernetes-runtime-auth' - else: - work_type = 'kubernetes-incluster-auth' - elif isinstance(self.task.instance, (Job, AdHocCommand)): - if self.task.instance.execution_node == self.task.instance.controller_node: - work_type = 'local' - else: - work_type = 'ansible-runner' - else: - work_type = 'local' - - return work_type + return 'kubernetes-runtime-auth' + return 'kubernetes-incluster-auth' + if self.task.instance.execution_node == settings.CLUSTER_HOST_ID or self.task.instance.execution_node == self.task.instance.controller_node: + return 'local' + return 'ansible-runner' @cleanup_new_process def cancel_watcher(self, processor_future): diff --git a/awx/main/tests/functional/conftest.py b/awx/main/tests/functional/conftest.py index b86edb90ec..7e2178ca4d 100644 --- a/awx/main/tests/functional/conftest.py +++ b/awx/main/tests/functional/conftest.py @@ -121,7 +121,7 @@ def run_computed_fields_right_away(request): @pytest.fixture @mock.patch.object(Project, "update", lambda self, **kwargs: None) -def project(instance, organization): +def project(organization): prj = Project.objects.create( name="test-proj", description="test-proj-desc", @@ -136,7 +136,7 @@ def project(instance, organization): @pytest.fixture @mock.patch.object(Project, "update", lambda self, **kwargs: None) -def manual_project(instance, organization): +def manual_project(organization): prj = Project.objects.create( name="test-manual-proj", description="manual-proj-desc", @@ -196,7 +196,7 @@ def instance(settings): @pytest.fixture -def organization(instance): +def organization(): return Organization.objects.create(name="test-org", description="test-org-desc") diff --git a/awx/main/tests/functional/task_management/test_scheduler.py b/awx/main/tests/functional/task_management/test_scheduler.py index 64dcc97415..d6794e4f77 100644 --- a/awx/main/tests/functional/task_management/test_scheduler.py +++ b/awx/main/tests/functional/task_management/test_scheduler.py @@ -7,6 +7,7 @@ from awx.main.scheduler import TaskManager from awx.main.scheduler.dependency_graph import DependencyGraph from awx.main.utils import encrypt_field from awx.main.models import WorkflowJobTemplate, JobTemplate, Job +from awx.main.models.ha import Instance, InstanceGroup @pytest.mark.django_db @@ -99,6 +100,48 @@ class TestJobLifeCycle: self.run_tm(tm, expect_schedule=[mock.call()]) wfjts[0].refresh_from_db() + @pytest.fixture + def control_instance(self): + '''Control instance in the controlplane automatic IG''' + ig = InstanceGroup.objects.create(name='controlplane') + inst = Instance.objects.create(hostname='control-1', node_type='control', capacity=500) + ig.instances.add(inst) + return inst + + @pytest.fixture + def execution_instance(self): + '''Execution node in the automatic default IG''' + ig = InstanceGroup.objects.create(name='default') + inst = Instance.objects.create(hostname='receptor-1', node_type='execution', capacity=500) + ig.instances.add(inst) + return inst + + def test_control_and_execution_instance(self, project, system_job_template, job_template, inventory_source, control_instance, execution_instance): + assert Instance.objects.count() == 2 + + pu = project.create_unified_job() + sj = system_job_template.create_unified_job() + job = job_template.create_unified_job() + inv_update = inventory_source.create_unified_job() + + all_ujs = (pu, sj, job, inv_update) + for uj in all_ujs: + uj.signal_start() + + tm = TaskManager() + self.run_tm(tm) + + for uj in all_ujs: + uj.refresh_from_db() + assert uj.status == 'waiting' + + for uj in (pu, sj): # control plane jobs + assert uj.capacity_type == 'control' + assert [uj.execution_node, uj.controller_node] == [control_instance.hostname, control_instance.hostname], uj + for uj in (job, inv_update): # user-space jobs + assert uj.capacity_type == 'execution' + assert [uj.execution_node, uj.controller_node] == [execution_instance.hostname, control_instance.hostname], uj + @pytest.mark.django_db def test_single_jt_multi_job_launch_blocks_last(default_instance_group, job_template_factory, mocker): diff --git a/awx/main/tests/functional/test_instances.py b/awx/main/tests/functional/test_instances.py index 7b8a2f41ab..65428886db 100644 --- a/awx/main/tests/functional/test_instances.py +++ b/awx/main/tests/functional/test_instances.py @@ -68,7 +68,7 @@ class TestPolicyTaskScheduling: @pytest.mark.django_db -def test_instance_dup(org_admin, organization, project, instance_factory, instance_group_factory, get, system_auditor): +def test_instance_dup(org_admin, organization, project, instance_factory, instance_group_factory, get, system_auditor, instance): i1 = instance_factory("i1") i2 = instance_factory("i2") i3 = instance_factory("i3") @@ -83,7 +83,7 @@ def test_instance_dup(org_admin, organization, project, instance_factory, instan api_num_instances_oa = list(list_response2.data.items())[0][1] assert actual_num_instances == api_num_instances_auditor - # Note: The org_admin will not see the default 'tower' node because it is not in it's group, as expected + # Note: The org_admin will not see the default 'tower' node (instance fixture) because it is not in it's group, as expected assert api_num_instances_oa == (actual_num_instances - 1) diff --git a/awx/main/tests/unit/models/test_ha.py b/awx/main/tests/unit/models/test_ha.py index ec71a47fc2..4a07cafd2d 100644 --- a/awx/main/tests/unit/models/test_ha.py +++ b/awx/main/tests/unit/models/test_ha.py @@ -17,8 +17,9 @@ def test_capacity_adjustment_no_save(capacity_adjustment): def T(impact): - j = mock.Mock(spec_set=['task_impact']) + j = mock.Mock(spec_set=['task_impact', 'capacity_type']) j.task_impact = impact + j.capacity_type = 'execution' return j @@ -35,11 +36,13 @@ def Is(param): inst = Mock() inst.capacity = capacity inst.jobs_running = jobs_running + inst.node_type = 'execution' instances.append(inst) else: for i in param: inst = Mock() inst.remaining_capacity = i + inst.node_type = 'execution' instances.append(inst) return instances diff --git a/awx/main/tests/unit/test_capacity.py b/awx/main/tests/unit/test_capacity.py index 8f35210088..fab27c6c76 100644 --- a/awx/main/tests/unit/test_capacity.py +++ b/awx/main/tests/unit/test_capacity.py @@ -3,10 +3,16 @@ import pytest from awx.main.models import InstanceGroup +class FakeMeta(object): + model_name = 'job' + + class FakeObject(object): def __init__(self, **kwargs): for k, v in kwargs.items(): setattr(self, k, v) + self._meta = FakeMeta() + self._meta.concrete_model = self class Job(FakeObject): @@ -85,7 +91,7 @@ def test_offline_node_running(sample_cluster): ig_small.instance_list[0].capacity = 0 tasks = [Job(status='running', execution_node='i1', instance_group=ig_small)] capacities = InstanceGroup.objects.capacity_values(qs=[default, ig_large, ig_small], tasks=tasks) - assert capacities['ig_small']['consumed_capacity'] == 43 + assert capacities['ig_small']['consumed_execution_capacity'] == 43 def test_offline_node_waiting(sample_cluster): @@ -96,7 +102,7 @@ def test_offline_node_waiting(sample_cluster): ig_small.instance_list[0].capacity = 0 tasks = [Job(status='waiting', instance_group=ig_small)] capacities = InstanceGroup.objects.capacity_values(qs=[default, ig_large, ig_small], tasks=tasks) - assert capacities['ig_small']['consumed_capacity'] == 43 + assert capacities['ig_small']['consumed_execution_capacity'] == 43 def test_RBAC_reduced_filter(sample_cluster): diff --git a/awx/main/tests/unit/utils/test_common.py b/awx/main/tests/unit/utils/test_common.py index 98aaefea2c..65d2b5d1f3 100644 --- a/awx/main/tests/unit/utils/test_common.py +++ b/awx/main/tests/unit/utils/test_common.py @@ -104,18 +104,32 @@ def test_get_type_for_model(model, name): assert common.get_type_for_model(model) == name -@pytest.mark.django_db def test_get_model_for_invalid_type(): with pytest.raises(LookupError): common.get_model_for_type('foobar') -@pytest.mark.django_db @pytest.mark.parametrize("model_type,model_class", [(name, cls) for cls, name in TEST_MODELS]) def test_get_model_for_valid_type(model_type, model_class): assert common.get_model_for_type(model_type) == model_class +@pytest.mark.parametrize("model_type,model_class", [(name, cls) for cls, name in TEST_MODELS]) +def test_get_capacity_type(model_type, model_class): + if model_type in ('job', 'ad_hoc_command', 'inventory_update', 'job_template'): + expectation = 'execution' + elif model_type in ('project_update', 'system_job'): + expectation = 'control' + else: + expectation = None + if model_type in ('unified_job', 'unified_job_template', 'inventory'): + with pytest.raises(RuntimeError): + common.get_capacity_type(model_class) + else: + assert common.get_capacity_type(model_class) == expectation + assert common.get_capacity_type(model_class()) == expectation + + @pytest.fixture def memoized_function(mocker, mock_cache): with mock.patch('awx.main.utils.common.get_memoize_cache', return_value=mock_cache): diff --git a/awx/main/utils/common.py b/awx/main/utils/common.py index dc324efaf5..6ed4922196 100644 --- a/awx/main/utils/common.py +++ b/awx/main/utils/common.py @@ -574,6 +574,20 @@ def get_model_for_type(type_name): return apps.get_model(use_app, model_str) +def get_capacity_type(uj): + '''Used for UnifiedJob.capacity_type property, static method will work for partial objects''' + model_name = uj._meta.concrete_model._meta.model_name + if model_name in ('job', 'inventoryupdate', 'adhoccommand', 'jobtemplate', 'inventorysource'): + return 'execution' + elif model_name == 'workflowjob': + return None + elif model_name.startswith('unified'): + raise RuntimeError(f'Capacity type is undefined for {model_name} model') + elif model_name in ('projectupdate', 'systemjob', 'project', 'systemjobtemplate'): + return 'control' + raise RuntimeError(f'Capacity type does not apply to {model_name} model') + + def prefetch_page_capabilities(model, page, prefetch_list, user): """ Given a `page` list of objects, a nested dictionary of user_capabilities