Apply policy task more selectively

This commit is contained in:
AlanCoding
2018-08-01 09:49:06 -04:00
parent 6f54f59485
commit a99ebbb02f
3 changed files with 146 additions and 51 deletions

View File

@@ -221,7 +221,46 @@ class PasswordFieldsModel(BaseModel):
update_fields.append(field) update_fields.append(field)
class PrimordialModel(CreatedModifiedModel): class HasEditsMixin(BaseModel):
"""Mixin which will keep the versions of field values from last edit
so we can tell if current model has unsaved changes.
"""
class Meta:
abstract = True
@classmethod
def _get_editable_fields(cls):
fds = set([])
for field in cls._meta.concrete_fields:
if hasattr(field, 'attname'):
if field.attname == 'id':
continue
elif field.attname.endswith('ptr_id'):
# polymorphic fields should always be non-editable, see:
# https://github.com/django-polymorphic/django-polymorphic/issues/349
continue
if getattr(field, 'editable', True):
fds.add(field.attname)
return fds
def _get_fields_snapshot(self, fields_set=None):
new_values = {}
if fields_set is None:
fields_set = self._get_editable_fields()
for attr, val in self.__dict__.items():
if attr in fields_set:
new_values[attr] = val
return new_values
def _values_have_edits(self, new_values):
return any(
new_values.get(fd_name, None) != self._prior_values_store.get(fd_name, None)
for fd_name in new_values.keys()
)
class PrimordialModel(HasEditsMixin, CreatedModifiedModel):
''' '''
Common model for all object types that have these standard fields Common model for all object types that have these standard fields
must use a subclass CommonModel or CommonModelNameNotUnique though must use a subclass CommonModel or CommonModelNameNotUnique though
@@ -270,42 +309,13 @@ class PrimordialModel(CreatedModifiedModel):
update_fields.append('created_by') update_fields.append('created_by')
# Update modified_by if any editable fields have changed # Update modified_by if any editable fields have changed
new_values = self._get_fields_snapshot() new_values = self._get_fields_snapshot()
if (not self.pk and not self.modified_by) or self.has_user_edits(new_values): if (not self.pk and not self.modified_by) or self._values_have_edits(new_values):
self.modified_by = user self.modified_by = user
if 'modified_by' not in update_fields: if 'modified_by' not in update_fields:
update_fields.append('modified_by') update_fields.append('modified_by')
super(PrimordialModel, self).save(*args, **kwargs) super(PrimordialModel, self).save(*args, **kwargs)
self._prior_values_store = new_values self._prior_values_store = new_values
def has_user_edits(self, new_values):
return any(
new_values.get(fd_name, None) != self._prior_values_store.get(fd_name, None)
for fd_name in new_values.keys()
)
@classmethod
def _get_editable_fields(cls):
fds = set([])
for field in cls._meta.concrete_fields:
if hasattr(field, 'attname'):
if field.attname == 'id':
continue
elif field.attname.endswith('ptr_id'):
# polymorphic fields should always be non-editable, see:
# https://github.com/django-polymorphic/django-polymorphic/issues/349
continue
if getattr(field, 'editable', True):
fds.add(field.attname)
return fds
def _get_fields_snapshot(self):
new_values = {}
editable_set = self._get_editable_fields()
for attr, val in self.__dict__.items():
if attr in editable_set:
new_values[attr] = val
return new_values
def clean_description(self): def clean_description(self):
# Description should always be empty string, never null. # Description should always be empty string, never null.
return self.description or '' return self.description or ''

View File

@@ -19,7 +19,7 @@ from awx import __version__ as awx_application_version
from awx.api.versioning import reverse from awx.api.versioning import reverse
from awx.main.managers import InstanceManager, InstanceGroupManager from awx.main.managers import InstanceManager, InstanceGroupManager
from awx.main.fields import JSONField from awx.main.fields import JSONField
from awx.main.models.base import BaseModel from awx.main.models.base import BaseModel, HasEditsMixin
from awx.main.models.inventory import InventoryUpdate from awx.main.models.inventory import InventoryUpdate
from awx.main.models.jobs import Job from awx.main.models.jobs import Job
from awx.main.models.projects import ProjectUpdate from awx.main.models.projects import ProjectUpdate
@@ -39,7 +39,28 @@ def validate_queuename(v):
raise ValidationError(_(six.text_type('{} contains unsupported characters')).format(v)) raise ValidationError(_(six.text_type('{} contains unsupported characters')).format(v))
class Instance(BaseModel): class HasPolicyEditsMixin(HasEditsMixin):
class Meta:
abstract = True
def __init__(self, *args, **kwargs):
r = super(BaseModel, self).__init__(*args, **kwargs)
self._prior_values_store = self._get_fields_snapshot()
return r
def save(self, *args, **kwargs):
super(BaseModel, self).save(*args, **kwargs)
self._prior_values_store = self._get_fields_snapshot()
def has_policy_changes(self):
if not hasattr(self, 'POLICY_FIELDS'):
raise RuntimeError('HasPolicyEditsMixin Model needs to set POLICY_FIELDS')
new_values = self._get_fields_snapshot(fields_set=self.POLICY_FIELDS)
return self._values_have_edits(new_values)
class Instance(HasPolicyEditsMixin, BaseModel):
"""A model representing an AWX instance running against this database.""" """A model representing an AWX instance running against this database."""
objects = InstanceManager() objects = InstanceManager()
@@ -87,6 +108,8 @@ class Instance(BaseModel):
class Meta: class Meta:
app_label = 'main' app_label = 'main'
POLICY_FIELDS = frozenset(('managed_by_policy', 'hostname', 'capacity_adjustment'))
def get_absolute_url(self, request=None): def get_absolute_url(self, request=None):
return reverse('api:instance_detail', kwargs={'pk': self.pk}, request=request) return reverse('api:instance_detail', kwargs={'pk': self.pk}, request=request)
@@ -144,7 +167,7 @@ class Instance(BaseModel):
class InstanceGroup(BaseModel, RelatedJobsMixin): class InstanceGroup(HasPolicyEditsMixin, BaseModel, RelatedJobsMixin):
"""A model representing a Queue/Group of AWX Instances.""" """A model representing a Queue/Group of AWX Instances."""
objects = InstanceGroupManager() objects = InstanceGroupManager()
@@ -179,6 +202,10 @@ class InstanceGroup(BaseModel, RelatedJobsMixin):
help_text=_("List of exact-match Instances that will always be automatically assigned to this group") help_text=_("List of exact-match Instances that will always be automatically assigned to this group")
) )
POLICY_FIELDS = frozenset((
'policy_instance_list', 'policy_instance_minimum', 'policy_instance_percentage', 'controller'
))
def get_absolute_url(self, request=None): def get_absolute_url(self, request=None):
return reverse('api:instance_group_detail', kwargs={'pk': self.pk}, request=request) return reverse('api:instance_group_detail', kwargs={'pk': self.pk}, request=request)
@@ -259,29 +286,31 @@ class JobOrigin(models.Model):
app_label = 'main' app_label = 'main'
@receiver(post_save, sender=InstanceGroup) def schedule_policy_task():
def on_instance_group_saved(sender, instance, created=False, raw=False, **kwargs):
from awx.main.tasks import apply_cluster_membership_policies from awx.main.tasks import apply_cluster_membership_policies
connection.on_commit(lambda: apply_cluster_membership_policies.apply_async()) connection.on_commit(lambda: apply_cluster_membership_policies.apply_async())
@receiver(post_save, sender=InstanceGroup)
def on_instance_group_saved(sender, instance, created=False, raw=False, **kwargs):
if created or instance.has_policy_changes():
schedule_policy_task()
@receiver(post_save, sender=Instance) @receiver(post_save, sender=Instance)
def on_instance_saved(sender, instance, created=False, raw=False, **kwargs): def on_instance_saved(sender, instance, created=False, raw=False, **kwargs):
if created: if created or instance.has_policy_changes():
from awx.main.tasks import apply_cluster_membership_policies schedule_policy_task()
connection.on_commit(lambda: apply_cluster_membership_policies.apply_async())
@receiver(post_delete, sender=InstanceGroup) @receiver(post_delete, sender=InstanceGroup)
def on_instance_group_deleted(sender, instance, using, **kwargs): def on_instance_group_deleted(sender, instance, using, **kwargs):
from awx.main.tasks import apply_cluster_membership_policies schedule_policy_task()
connection.on_commit(lambda: apply_cluster_membership_policies.apply_async())
@receiver(post_delete, sender=Instance) @receiver(post_delete, sender=Instance)
def on_instance_deleted(sender, instance, using, **kwargs): def on_instance_deleted(sender, instance, using, **kwargs):
from awx.main.tasks import apply_cluster_membership_policies schedule_policy_task()
connection.on_commit(lambda: apply_cluster_membership_policies.apply_async())
# Unfortunately, the signal can't just be connected against UnifiedJob; it # Unfortunately, the signal can't just be connected against UnifiedJob; it

View File

@@ -1,15 +1,66 @@
import pytest import pytest
import mock
from awx.main.models import AdHocCommand, InventoryUpdate, Job, JobTemplate, ProjectUpdate, Instance from awx.main.models import AdHocCommand, InventoryUpdate, Job, JobTemplate, ProjectUpdate
from awx.main.models.ha import Instance, InstanceGroup
from awx.main.tasks import apply_cluster_membership_policies from awx.main.tasks import apply_cluster_membership_policies
from awx.api.versioning import reverse from awx.api.versioning import reverse
from django.utils.timezone import now
@pytest.mark.django_db @pytest.mark.django_db
def test_default_tower_instance_group(default_instance_group, job_factory): def test_default_tower_instance_group(default_instance_group, job_factory):
assert default_instance_group in job_factory().preferred_instance_groups assert default_instance_group in job_factory().preferred_instance_groups
@pytest.mark.django_db
class TestPolicyTaskScheduling:
"""Tests make assertions about when the policy task gets scheduled"""
@pytest.mark.parametrize('field, value, expect', [
('name', 'foo-bar-foo-bar', False),
('policy_instance_percentage', 35, True),
('policy_instance_minimum', 3, True),
('policy_instance_list', ['bar?'], True),
('modified', now(), False)
])
def test_policy_task_ran_for_ig_when_needed(self, instance_group_factory, field, value, expect):
# always run on instance group creation
with mock.patch('awx.main.models.ha.schedule_policy_task') as mock_policy:
ig = InstanceGroup.objects.create(name='foo')
mock_policy.assert_called_once()
# selectively run on instance group modification
with mock.patch('awx.main.models.ha.schedule_policy_task') as mock_policy:
setattr(ig, field, value)
ig.save()
if expect:
mock_policy.assert_called_once()
else:
mock_policy.assert_not_called()
@pytest.mark.parametrize('field, value, expect', [
('hostname', 'foo-bar-foo-bar', True),
('managed_by_policy', False, True),
('enabled', False, False),
('capacity_adjustment', 0.42, True),
('capacity', 42, False)
])
def test_policy_task_ran_for_instance_when_needed(self, instance_group_factory, field, value, expect):
# always run on instance group creation
with mock.patch('awx.main.models.ha.schedule_policy_task') as mock_policy:
inst = Instance.objects.create(hostname='foo')
mock_policy.assert_called_once()
# selectively run on instance group modification
with mock.patch('awx.main.models.ha.schedule_policy_task') as mock_policy:
setattr(inst, field, value)
inst.save()
if expect:
mock_policy.assert_called_once()
else:
mock_policy.assert_not_called()
@pytest.mark.django_db @pytest.mark.django_db
def test_instance_dup(org_admin, organization, project, instance_factory, instance_group_factory, get, system_auditor): def test_instance_dup(org_admin, organization, project, instance_factory, instance_group_factory, get, system_auditor):
i1 = instance_factory("i1") i1 = instance_factory("i1")
@@ -167,18 +218,23 @@ def test_policy_instance_list_manually_assigned(instance_factory, instance_group
def test_policy_instance_list_explicitly_pinned(instance_factory, instance_group_factory): def test_policy_instance_list_explicitly_pinned(instance_factory, instance_group_factory):
i1 = instance_factory("i1") i1 = instance_factory("i1")
i2 = instance_factory("i2") i2 = instance_factory("i2")
i2.managed_by_policy = False
i2.save()
ig_1 = instance_group_factory("ig1", percentage=100, minimum=2) ig_1 = instance_group_factory("ig1", percentage=100, minimum=2)
ig_2 = instance_group_factory("ig2") ig_2 = instance_group_factory("ig2")
ig_2.policy_instance_list = [i2.hostname] ig_2.policy_instance_list = [i2.hostname]
ig_2.save() ig_2.save()
# without being marked as manual, i2 will be picked up by ig_1
apply_cluster_membership_policies() apply_cluster_membership_policies()
assert len(ig_1.instances.all()) == 1 assert set(ig_1.instances.all()) == set([i1, i2])
assert i1 in ig_1.instances.all() assert set(ig_2.instances.all()) == set([i2])
assert i2 not in ig_1.instances.all()
assert len(ig_2.instances.all()) == 1 i2.managed_by_policy = False
assert i2 in ig_2.instances.all() i2.save()
# after marking as manual, i2 no longer available for ig_1
apply_cluster_membership_policies()
assert set(ig_1.instances.all()) == set([i1])
assert set(ig_2.instances.all()) == set([i2])
@pytest.mark.django_db @pytest.mark.django_db