diff --git a/awx/main/access.py b/awx/main/access.py index 7be1838017..70346d8b4b 100644 --- a/awx/main/access.py +++ b/awx/main/access.py @@ -1044,6 +1044,8 @@ class JobTemplateAccess(BaseAccess): self.check_license(feature='system_tracking') if obj.survey_enabled: self.check_license(feature='surveys') + if Instance.objects.active_count() > 1: + self.check_license(feature='ha') # Super users can start any job if self.user.is_superuser: diff --git a/awx/main/managers.py b/awx/main/managers.py index 176deb9483..c054584b0c 100644 --- a/awx/main/managers.py +++ b/awx/main/managers.py @@ -36,6 +36,10 @@ class InstanceManager(models.Manager): return node[0] raise RuntimeError("No instance found with the current cluster host id") + def active_count(self): + """Return count of active Tower nodes for licensing.""" + return self.all().count() + def my_role(self): # NOTE: TODO: Likely to repurpose this once standalone ramparts are a thing return "tower" diff --git a/awx/main/tests/unit/test_access.py b/awx/main/tests/unit/test_access.py index 650ed19864..fa6c34b95e 100644 --- a/awx/main/tests/unit/test_access.py +++ b/awx/main/tests/unit/test_access.py @@ -10,7 +10,8 @@ from awx.main.access import ( JobTemplateAccess, WorkflowJobTemplateAccess, ) -from awx.main.models import Credential, Inventory, Project, Role, Organization +from awx.conf.license import LicenseForbids +from awx.main.models import Credential, Inventory, Project, Role, Organization, Instance @pytest.fixture @@ -106,6 +107,18 @@ def test_jt_add_scan_job_check(job_template_with_ids, user_unit): 'job_type': 'scan' }) +def mock_raise_license_forbids(self, add_host=False, feature=None, check_expiration=True): + raise LicenseForbids("Feature not enabled") + +def mock_raise_none(self, add_host=False, feature=None, check_expiration=True): + return None + +def test_jt_can_start_ha(job_template_with_ids): + with mock.patch.object(Instance.objects, 'active_count', return_value=2): + with mock.patch('awx.main.access.BaseAccess.check_license', new=mock_raise_license_forbids): + with pytest.raises(LicenseForbids): + JobTemplateAccess(user_unit).can_start(job_template_with_ids) + def test_jt_can_add_bad_data(user_unit): "Assure that no server errors are returned if we call JT can_add with bad data" access = JobTemplateAccess(user_unit)