diff --git a/awx/main/tests/unit/utils/test_ha.py b/awx/main/tests/unit/utils/test_ha.py index 94cb7d3606..f73f8e908a 100644 --- a/awx/main/tests/unit/utils/test_ha.py +++ b/awx/main/tests/unit/utils/test_ha.py @@ -18,7 +18,8 @@ from awx.main.utils.ha import ( class TestAddRemoveCeleryWorkerQueues(): @pytest.fixture def instance_generator(self, mocker): - def fn(groups=['east', 'west', 'north', 'south'], hostname='east-1'): + def fn(hostname='east-1'): + groups=['east', 'west', 'north', 'south'] instance = mocker.MagicMock() instance.hostname = hostname instance.rampart_groups = mocker.MagicMock() @@ -40,29 +41,29 @@ class TestAddRemoveCeleryWorkerQueues(): app.control.cancel_consumer = mocker.MagicMock() return app - @pytest.mark.parametrize("broadcast_queues,static_queues,_worker_queues,groups,hostname,added_expected,removed_expected", [ - (['tower_broadcast_all'], ['east', 'west'], ['east', 'west', 'east-1'], [], 'east-1', ['tower_broadcast_all_east-1'], []), - ([], [], ['east', 'west', 'east-1'], ['east', 'west'], 'east-1', [], []), - ([], [], ['east', 'west'], ['east', 'west'], 'east-1', ['east-1'], []), - ([], [], [], ['east', 'west'], 'east-1', ['east', 'west', 'east-1'], []), - ([], [], ['china', 'russia'], ['east', 'west'], 'east-1', ['east', 'west', 'east-1'], ['china', 'russia']), + @pytest.mark.parametrize("broadcast_queues,static_queues,_worker_queues,hostname,added_expected,removed_expected", [ + (['tower_broadcast_all'], ['east', 'west'], ['east', 'west', 'east-1'], 'east-1', ['tower_broadcast_all_east-1'], []), + ([], [], ['east', 'west', 'east-1'], 'east-1', [], ['east', 'west']), + ([], [], ['east', 'west'], 'east-1', ['east-1'], ['east', 'west']), + ([], [], [], 'east-1', ['east-1'], []), + ([], [], ['china', 'russia'], 'east-1', [ 'east-1'], ['china', 'russia']), ]) def test__add_remove_celery_worker_queues_noop(self, mock_app, instance_generator, worker_queues_generator, broadcast_queues, static_queues, _worker_queues, - groups, hostname, + hostname, added_expected, removed_expected): - instance = instance_generator(groups=groups, hostname=hostname) + instance = instance_generator(hostname=hostname) worker_queues = worker_queues_generator(_worker_queues) with nested( mock.patch('awx.main.utils.ha.settings.AWX_CELERY_QUEUES_STATIC', static_queues), mock.patch('awx.main.utils.ha.settings.AWX_CELERY_BCAST_QUEUES_STATIC', broadcast_queues), mock.patch('awx.main.utils.ha.settings.CLUSTER_HOST_ID', hostname)): (added_queues, removed_queues) = _add_remove_celery_worker_queues(mock_app, [instance], worker_queues, hostname) - assert set(added_queues) == set(added_expected) - assert set(removed_queues) == set(removed_expected) + assert set(added_expected) == set(added_queues) + assert set(removed_expected) == set(removed_queues) class TestUpdateCeleryWorkerRouter(): diff --git a/awx/main/utils/ha.py b/awx/main/utils/ha.py index 49421ad4cb..dd629ca24d 100644 --- a/awx/main/utils/ha.py +++ b/awx/main/utils/ha.py @@ -17,14 +17,11 @@ def construct_bcast_queue_name(common_name): def _add_remove_celery_worker_queues(app, controlled_instances, worker_queues, worker_name): removed_queues = [] added_queues = [] - ig_names = set() hostnames = set([instance.hostname for instance in controlled_instances]) - for instance in controlled_instances: - ig_names.update(instance.rampart_groups.values_list('name', flat=True)) worker_queue_names = set([q['name'] for q in worker_queues]) bcast_queue_names = set([construct_bcast_queue_name(n) for n in settings.AWX_CELERY_BCAST_QUEUES_STATIC]) - all_queue_names = ig_names | hostnames | set(settings.AWX_CELERY_QUEUES_STATIC) + all_queue_names = hostnames | set(settings.AWX_CELERY_QUEUES_STATIC) desired_queues = bcast_queue_names | (all_queue_names if instance.enabled else set()) # Remove queues @@ -33,7 +30,7 @@ def _add_remove_celery_worker_queues(app, controlled_instances, worker_queues, w app.control.cancel_consumer(queue_name.encode("utf8"), reply=True, destination=[worker_name]) removed_queues.append(queue_name.encode("utf8")) - # Add queues for instance and instance groups + # Add queues for instances for queue_name in all_queue_names: if queue_name not in worker_queue_names: app.control.add_consumer(queue_name.encode("utf8"), reply=True, destination=[worker_name])