mirror of
https://github.com/ansible/awx.git
synced 2026-02-21 21:20:08 -03:30
Apply policy task more selectively
This commit is contained in:
@@ -221,7 +221,46 @@ class PasswordFieldsModel(BaseModel):
|
|||||||
update_fields.append(field)
|
update_fields.append(field)
|
||||||
|
|
||||||
|
|
||||||
class PrimordialModel(CreatedModifiedModel):
|
class HasEditsMixin(BaseModel):
|
||||||
|
"""Mixin which will keep the versions of field values from last edit
|
||||||
|
so we can tell if current model has unsaved changes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
abstract = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_editable_fields(cls):
|
||||||
|
fds = set([])
|
||||||
|
for field in cls._meta.concrete_fields:
|
||||||
|
if hasattr(field, 'attname'):
|
||||||
|
if field.attname == 'id':
|
||||||
|
continue
|
||||||
|
elif field.attname.endswith('ptr_id'):
|
||||||
|
# polymorphic fields should always be non-editable, see:
|
||||||
|
# https://github.com/django-polymorphic/django-polymorphic/issues/349
|
||||||
|
continue
|
||||||
|
if getattr(field, 'editable', True):
|
||||||
|
fds.add(field.attname)
|
||||||
|
return fds
|
||||||
|
|
||||||
|
def _get_fields_snapshot(self, fields_set=None):
|
||||||
|
new_values = {}
|
||||||
|
if fields_set is None:
|
||||||
|
fields_set = self._get_editable_fields()
|
||||||
|
for attr, val in self.__dict__.items():
|
||||||
|
if attr in fields_set:
|
||||||
|
new_values[attr] = val
|
||||||
|
return new_values
|
||||||
|
|
||||||
|
def _values_have_edits(self, new_values):
|
||||||
|
return any(
|
||||||
|
new_values.get(fd_name, None) != self._prior_values_store.get(fd_name, None)
|
||||||
|
for fd_name in new_values.keys()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PrimordialModel(HasEditsMixin, CreatedModifiedModel):
|
||||||
'''
|
'''
|
||||||
Common model for all object types that have these standard fields
|
Common model for all object types that have these standard fields
|
||||||
must use a subclass CommonModel or CommonModelNameNotUnique though
|
must use a subclass CommonModel or CommonModelNameNotUnique though
|
||||||
@@ -270,42 +309,13 @@ class PrimordialModel(CreatedModifiedModel):
|
|||||||
update_fields.append('created_by')
|
update_fields.append('created_by')
|
||||||
# Update modified_by if any editable fields have changed
|
# Update modified_by if any editable fields have changed
|
||||||
new_values = self._get_fields_snapshot()
|
new_values = self._get_fields_snapshot()
|
||||||
if (not self.pk and not self.modified_by) or self.has_user_edits(new_values):
|
if (not self.pk and not self.modified_by) or self._values_have_edits(new_values):
|
||||||
self.modified_by = user
|
self.modified_by = user
|
||||||
if 'modified_by' not in update_fields:
|
if 'modified_by' not in update_fields:
|
||||||
update_fields.append('modified_by')
|
update_fields.append('modified_by')
|
||||||
super(PrimordialModel, self).save(*args, **kwargs)
|
super(PrimordialModel, self).save(*args, **kwargs)
|
||||||
self._prior_values_store = new_values
|
self._prior_values_store = new_values
|
||||||
|
|
||||||
def has_user_edits(self, new_values):
|
|
||||||
return any(
|
|
||||||
new_values.get(fd_name, None) != self._prior_values_store.get(fd_name, None)
|
|
||||||
for fd_name in new_values.keys()
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_editable_fields(cls):
|
|
||||||
fds = set([])
|
|
||||||
for field in cls._meta.concrete_fields:
|
|
||||||
if hasattr(field, 'attname'):
|
|
||||||
if field.attname == 'id':
|
|
||||||
continue
|
|
||||||
elif field.attname.endswith('ptr_id'):
|
|
||||||
# polymorphic fields should always be non-editable, see:
|
|
||||||
# https://github.com/django-polymorphic/django-polymorphic/issues/349
|
|
||||||
continue
|
|
||||||
if getattr(field, 'editable', True):
|
|
||||||
fds.add(field.attname)
|
|
||||||
return fds
|
|
||||||
|
|
||||||
def _get_fields_snapshot(self):
|
|
||||||
new_values = {}
|
|
||||||
editable_set = self._get_editable_fields()
|
|
||||||
for attr, val in self.__dict__.items():
|
|
||||||
if attr in editable_set:
|
|
||||||
new_values[attr] = val
|
|
||||||
return new_values
|
|
||||||
|
|
||||||
def clean_description(self):
|
def clean_description(self):
|
||||||
# Description should always be empty string, never null.
|
# Description should always be empty string, never null.
|
||||||
return self.description or ''
|
return self.description or ''
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from awx import __version__ as awx_application_version
|
|||||||
from awx.api.versioning import reverse
|
from awx.api.versioning import reverse
|
||||||
from awx.main.managers import InstanceManager, InstanceGroupManager
|
from awx.main.managers import InstanceManager, InstanceGroupManager
|
||||||
from awx.main.fields import JSONField
|
from awx.main.fields import JSONField
|
||||||
from awx.main.models.base import BaseModel
|
from awx.main.models.base import BaseModel, HasEditsMixin
|
||||||
from awx.main.models.inventory import InventoryUpdate
|
from awx.main.models.inventory import InventoryUpdate
|
||||||
from awx.main.models.jobs import Job
|
from awx.main.models.jobs import Job
|
||||||
from awx.main.models.projects import ProjectUpdate
|
from awx.main.models.projects import ProjectUpdate
|
||||||
@@ -39,7 +39,28 @@ def validate_queuename(v):
|
|||||||
raise ValidationError(_(six.text_type('{} contains unsupported characters')).format(v))
|
raise ValidationError(_(six.text_type('{} contains unsupported characters')).format(v))
|
||||||
|
|
||||||
|
|
||||||
class Instance(BaseModel):
|
class HasPolicyEditsMixin(HasEditsMixin):
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
abstract = True
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
r = super(BaseModel, self).__init__(*args, **kwargs)
|
||||||
|
self._prior_values_store = self._get_fields_snapshot()
|
||||||
|
return r
|
||||||
|
|
||||||
|
def save(self, *args, **kwargs):
|
||||||
|
super(BaseModel, self).save(*args, **kwargs)
|
||||||
|
self._prior_values_store = self._get_fields_snapshot()
|
||||||
|
|
||||||
|
def has_policy_changes(self):
|
||||||
|
if not hasattr(self, 'POLICY_FIELDS'):
|
||||||
|
raise RuntimeError('HasPolicyEditsMixin Model needs to set POLICY_FIELDS')
|
||||||
|
new_values = self._get_fields_snapshot(fields_set=self.POLICY_FIELDS)
|
||||||
|
return self._values_have_edits(new_values)
|
||||||
|
|
||||||
|
|
||||||
|
class Instance(HasPolicyEditsMixin, BaseModel):
|
||||||
"""A model representing an AWX instance running against this database."""
|
"""A model representing an AWX instance running against this database."""
|
||||||
objects = InstanceManager()
|
objects = InstanceManager()
|
||||||
|
|
||||||
@@ -87,6 +108,8 @@ class Instance(BaseModel):
|
|||||||
class Meta:
|
class Meta:
|
||||||
app_label = 'main'
|
app_label = 'main'
|
||||||
|
|
||||||
|
POLICY_FIELDS = frozenset(('managed_by_policy', 'hostname', 'capacity_adjustment'))
|
||||||
|
|
||||||
def get_absolute_url(self, request=None):
|
def get_absolute_url(self, request=None):
|
||||||
return reverse('api:instance_detail', kwargs={'pk': self.pk}, request=request)
|
return reverse('api:instance_detail', kwargs={'pk': self.pk}, request=request)
|
||||||
|
|
||||||
@@ -144,7 +167,7 @@ class Instance(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
class InstanceGroup(BaseModel, RelatedJobsMixin):
|
class InstanceGroup(HasPolicyEditsMixin, BaseModel, RelatedJobsMixin):
|
||||||
"""A model representing a Queue/Group of AWX Instances."""
|
"""A model representing a Queue/Group of AWX Instances."""
|
||||||
objects = InstanceGroupManager()
|
objects = InstanceGroupManager()
|
||||||
|
|
||||||
@@ -179,6 +202,10 @@ class InstanceGroup(BaseModel, RelatedJobsMixin):
|
|||||||
help_text=_("List of exact-match Instances that will always be automatically assigned to this group")
|
help_text=_("List of exact-match Instances that will always be automatically assigned to this group")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
POLICY_FIELDS = frozenset((
|
||||||
|
'policy_instance_list', 'policy_instance_minimum', 'policy_instance_percentage', 'controller'
|
||||||
|
))
|
||||||
|
|
||||||
def get_absolute_url(self, request=None):
|
def get_absolute_url(self, request=None):
|
||||||
return reverse('api:instance_group_detail', kwargs={'pk': self.pk}, request=request)
|
return reverse('api:instance_group_detail', kwargs={'pk': self.pk}, request=request)
|
||||||
|
|
||||||
@@ -259,29 +286,31 @@ class JobOrigin(models.Model):
|
|||||||
app_label = 'main'
|
app_label = 'main'
|
||||||
|
|
||||||
|
|
||||||
@receiver(post_save, sender=InstanceGroup)
|
def schedule_policy_task():
|
||||||
def on_instance_group_saved(sender, instance, created=False, raw=False, **kwargs):
|
|
||||||
from awx.main.tasks import apply_cluster_membership_policies
|
from awx.main.tasks import apply_cluster_membership_policies
|
||||||
connection.on_commit(lambda: apply_cluster_membership_policies.apply_async())
|
connection.on_commit(lambda: apply_cluster_membership_policies.apply_async())
|
||||||
|
|
||||||
|
|
||||||
|
@receiver(post_save, sender=InstanceGroup)
|
||||||
|
def on_instance_group_saved(sender, instance, created=False, raw=False, **kwargs):
|
||||||
|
if created or instance.has_policy_changes():
|
||||||
|
schedule_policy_task()
|
||||||
|
|
||||||
|
|
||||||
@receiver(post_save, sender=Instance)
|
@receiver(post_save, sender=Instance)
|
||||||
def on_instance_saved(sender, instance, created=False, raw=False, **kwargs):
|
def on_instance_saved(sender, instance, created=False, raw=False, **kwargs):
|
||||||
if created:
|
if created or instance.has_policy_changes():
|
||||||
from awx.main.tasks import apply_cluster_membership_policies
|
schedule_policy_task()
|
||||||
connection.on_commit(lambda: apply_cluster_membership_policies.apply_async())
|
|
||||||
|
|
||||||
|
|
||||||
@receiver(post_delete, sender=InstanceGroup)
|
@receiver(post_delete, sender=InstanceGroup)
|
||||||
def on_instance_group_deleted(sender, instance, using, **kwargs):
|
def on_instance_group_deleted(sender, instance, using, **kwargs):
|
||||||
from awx.main.tasks import apply_cluster_membership_policies
|
schedule_policy_task()
|
||||||
connection.on_commit(lambda: apply_cluster_membership_policies.apply_async())
|
|
||||||
|
|
||||||
|
|
||||||
@receiver(post_delete, sender=Instance)
|
@receiver(post_delete, sender=Instance)
|
||||||
def on_instance_deleted(sender, instance, using, **kwargs):
|
def on_instance_deleted(sender, instance, using, **kwargs):
|
||||||
from awx.main.tasks import apply_cluster_membership_policies
|
schedule_policy_task()
|
||||||
connection.on_commit(lambda: apply_cluster_membership_policies.apply_async())
|
|
||||||
|
|
||||||
|
|
||||||
# Unfortunately, the signal can't just be connected against UnifiedJob; it
|
# Unfortunately, the signal can't just be connected against UnifiedJob; it
|
||||||
|
|||||||
@@ -1,15 +1,66 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
import mock
|
||||||
|
|
||||||
from awx.main.models import AdHocCommand, InventoryUpdate, Job, JobTemplate, ProjectUpdate, Instance
|
from awx.main.models import AdHocCommand, InventoryUpdate, Job, JobTemplate, ProjectUpdate
|
||||||
|
from awx.main.models.ha import Instance, InstanceGroup
|
||||||
from awx.main.tasks import apply_cluster_membership_policies
|
from awx.main.tasks import apply_cluster_membership_policies
|
||||||
from awx.api.versioning import reverse
|
from awx.api.versioning import reverse
|
||||||
|
|
||||||
|
from django.utils.timezone import now
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
def test_default_tower_instance_group(default_instance_group, job_factory):
|
def test_default_tower_instance_group(default_instance_group, job_factory):
|
||||||
assert default_instance_group in job_factory().preferred_instance_groups
|
assert default_instance_group in job_factory().preferred_instance_groups
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
class TestPolicyTaskScheduling:
|
||||||
|
"""Tests make assertions about when the policy task gets scheduled"""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('field, value, expect', [
|
||||||
|
('name', 'foo-bar-foo-bar', False),
|
||||||
|
('policy_instance_percentage', 35, True),
|
||||||
|
('policy_instance_minimum', 3, True),
|
||||||
|
('policy_instance_list', ['bar?'], True),
|
||||||
|
('modified', now(), False)
|
||||||
|
])
|
||||||
|
def test_policy_task_ran_for_ig_when_needed(self, instance_group_factory, field, value, expect):
|
||||||
|
# always run on instance group creation
|
||||||
|
with mock.patch('awx.main.models.ha.schedule_policy_task') as mock_policy:
|
||||||
|
ig = InstanceGroup.objects.create(name='foo')
|
||||||
|
mock_policy.assert_called_once()
|
||||||
|
# selectively run on instance group modification
|
||||||
|
with mock.patch('awx.main.models.ha.schedule_policy_task') as mock_policy:
|
||||||
|
setattr(ig, field, value)
|
||||||
|
ig.save()
|
||||||
|
if expect:
|
||||||
|
mock_policy.assert_called_once()
|
||||||
|
else:
|
||||||
|
mock_policy.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('field, value, expect', [
|
||||||
|
('hostname', 'foo-bar-foo-bar', True),
|
||||||
|
('managed_by_policy', False, True),
|
||||||
|
('enabled', False, False),
|
||||||
|
('capacity_adjustment', 0.42, True),
|
||||||
|
('capacity', 42, False)
|
||||||
|
])
|
||||||
|
def test_policy_task_ran_for_instance_when_needed(self, instance_group_factory, field, value, expect):
|
||||||
|
# always run on instance group creation
|
||||||
|
with mock.patch('awx.main.models.ha.schedule_policy_task') as mock_policy:
|
||||||
|
inst = Instance.objects.create(hostname='foo')
|
||||||
|
mock_policy.assert_called_once()
|
||||||
|
# selectively run on instance group modification
|
||||||
|
with mock.patch('awx.main.models.ha.schedule_policy_task') as mock_policy:
|
||||||
|
setattr(inst, field, value)
|
||||||
|
inst.save()
|
||||||
|
if expect:
|
||||||
|
mock_policy.assert_called_once()
|
||||||
|
else:
|
||||||
|
mock_policy.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
def test_instance_dup(org_admin, organization, project, instance_factory, instance_group_factory, get, system_auditor):
|
def test_instance_dup(org_admin, organization, project, instance_factory, instance_group_factory, get, system_auditor):
|
||||||
i1 = instance_factory("i1")
|
i1 = instance_factory("i1")
|
||||||
@@ -167,18 +218,23 @@ def test_policy_instance_list_manually_assigned(instance_factory, instance_group
|
|||||||
def test_policy_instance_list_explicitly_pinned(instance_factory, instance_group_factory):
|
def test_policy_instance_list_explicitly_pinned(instance_factory, instance_group_factory):
|
||||||
i1 = instance_factory("i1")
|
i1 = instance_factory("i1")
|
||||||
i2 = instance_factory("i2")
|
i2 = instance_factory("i2")
|
||||||
i2.managed_by_policy = False
|
|
||||||
i2.save()
|
|
||||||
ig_1 = instance_group_factory("ig1", percentage=100, minimum=2)
|
ig_1 = instance_group_factory("ig1", percentage=100, minimum=2)
|
||||||
ig_2 = instance_group_factory("ig2")
|
ig_2 = instance_group_factory("ig2")
|
||||||
ig_2.policy_instance_list = [i2.hostname]
|
ig_2.policy_instance_list = [i2.hostname]
|
||||||
ig_2.save()
|
ig_2.save()
|
||||||
|
|
||||||
|
# without being marked as manual, i2 will be picked up by ig_1
|
||||||
apply_cluster_membership_policies()
|
apply_cluster_membership_policies()
|
||||||
assert len(ig_1.instances.all()) == 1
|
assert set(ig_1.instances.all()) == set([i1, i2])
|
||||||
assert i1 in ig_1.instances.all()
|
assert set(ig_2.instances.all()) == set([i2])
|
||||||
assert i2 not in ig_1.instances.all()
|
|
||||||
assert len(ig_2.instances.all()) == 1
|
i2.managed_by_policy = False
|
||||||
assert i2 in ig_2.instances.all()
|
i2.save()
|
||||||
|
|
||||||
|
# after marking as manual, i2 no longer available for ig_1
|
||||||
|
apply_cluster_membership_policies()
|
||||||
|
assert set(ig_1.instances.all()) == set([i1])
|
||||||
|
assert set(ig_2.instances.all()) == set([i2])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
|
|||||||
Reference in New Issue
Block a user