From 967624c5765cc963d6d065ef8baab8c92e040b3b Mon Sep 17 00:00:00 2001 From: AlanCoding Date: Thu, 17 May 2018 14:41:19 -0400 Subject: [PATCH] fix schedules modified_by getting nulled --- awx/main/models/base.py | 40 ++++++++++++++++--- awx/main/models/projects.py | 22 +++------- .../tests/functional/models/test_project.py | 8 ++++ .../tests/functional/models/test_schedule.py | 18 +++++++++ 4 files changed, 65 insertions(+), 23 deletions(-) diff --git a/awx/main/models/base.py b/awx/main/models/base.py index fcca82474c..709e603967 100644 --- a/awx/main/models/base.py +++ b/awx/main/models/base.py @@ -254,9 +254,13 @@ class PrimordialModel(CreatedModifiedModel): tags = TaggableManager(blank=True) + def __init__(self, *args, **kwargs): + r = super(PrimordialModel, self).__init__(*args, **kwargs) + self._prior_values_store = self._get_fields_snapshot() + return r + def save(self, *args, **kwargs): update_fields = kwargs.get('update_fields', []) - fields_are_specified = bool(update_fields) user = get_current_user() if user and not user.id: user = None @@ -264,15 +268,39 @@ class PrimordialModel(CreatedModifiedModel): self.created_by = user if 'created_by' not in update_fields: update_fields.append('created_by') - # Update modified_by if not called with update_fields, or if any - # editable fields are present in update_fields - if ( - (not fields_are_specified) or - any(getattr(self._meta.get_field(name), 'editable', True) for name in update_fields)): + # Update modified_by if any editable fields have changed + new_values = self._get_fields_snapshot() + if (not self.pk and not self.modified_by) or self.has_user_edits(new_values): self.modified_by = user if 'modified_by' not in update_fields: update_fields.append('modified_by') super(PrimordialModel, self).save(*args, **kwargs) + self._prior_values_store = new_values + + def has_user_edits(self, new_values): + return any( + new_values.get(fd_name, None) != self._prior_values_store.get(fd_name, None) + for fd_name in new_values.keys() + ) + + @classmethod + def _get_editable_fields(cls): + fds = set([]) + for field in cls._meta.concrete_fields: + if hasattr(field, 'attname'): + if field.attname == 'id': + continue + if getattr(field, 'editable', True): + fds.add(field.attname) + return fds + + def _get_fields_snapshot(self): + new_values = {} + editable_set = self._get_editable_fields() + for attr, val in self.__dict__.items(): + if attr in editable_set: + new_values[attr] = val + return new_values def clean_description(self): # Description should always be empty string, never null. diff --git a/awx/main/models/projects.py b/awx/main/models/projects.py index 3bad19c8eb..dce47eb8dd 100644 --- a/awx/main/models/projects.py +++ b/awx/main/models/projects.py @@ -324,13 +324,9 @@ class Project(UnifiedJobTemplate, ProjectOptions, ResourceMixin, CustomVirtualEn ['name', 'description', 'schedule'] ) - def __init__(self, *args, **kwargs): - r = super(Project, self).__init__(*args, **kwargs) - self._prior_values_store = self._current_sensitive_fields() - return r - def save(self, *args, **kwargs): new_instance = not bool(self.pk) + pre_save_vals = getattr(self, '_prior_values_store', {}) # If update_fields has been specified, add our field names to it, # if it hasn't been specified, then we're just doing a normal save. update_fields = kwargs.get('update_fields', []) @@ -361,21 +357,13 @@ class Project(UnifiedJobTemplate, ProjectOptions, ResourceMixin, CustomVirtualEn self.save(update_fields=update_fields) # If we just created a new project with SCM, start the initial update. # also update if certain fields have changed - relevant_change = False - new_values = self._current_sensitive_fields() - if hasattr(self, '_prior_values_store') and self._prior_values_store != new_values: - relevant_change = True - self._prior_values_store = new_values + relevant_change = any( + pre_save_vals.get(fd_name, None) != self._prior_values_store.get(fd_name, None) + for fd_name in self.FIELDS_TRIGGER_UPDATE + ) if (relevant_change or new_instance) and (not skip_update) and self.scm_type: self.update() - def _current_sensitive_fields(self): - new_values = {} - for attr, val in self.__dict__.items(): - if attr in Project.FIELDS_TRIGGER_UPDATE: - new_values[attr] = val - return new_values - def _get_current_status(self): if self.scm_type: if self.current_job and self.current_job.status: diff --git a/awx/main/tests/functional/models/test_project.py b/awx/main/tests/functional/models/test_project.py index 71352ed633..f150dbe00a 100644 --- a/awx/main/tests/functional/models/test_project.py +++ b/awx/main/tests/functional/models/test_project.py @@ -2,6 +2,7 @@ import pytest import mock from awx.main.models import Project +from awx.main.models.organization import Organization @pytest.mark.django_db @@ -31,3 +32,10 @@ def test_sensitive_change_triggers_update(project): project.scm_url = 'https://foo2.invalid' project.save() mock_update.assert_called_once_with() + + +@pytest.mark.django_db +def test_foreign_key_change_changes_modified_by(project, organization): + assert project._get_fields_snapshot()['organization_id'] == organization.id + project.organization = Organization(name='foo', pk=41) + assert project._get_fields_snapshot()['organization_id'] == 41 diff --git a/awx/main/tests/functional/models/test_schedule.py b/awx/main/tests/functional/models/test_schedule.py index d18e848d97..0921971c47 100644 --- a/awx/main/tests/functional/models/test_schedule.py +++ b/awx/main/tests/functional/models/test_schedule.py @@ -7,6 +7,8 @@ import pytz from awx.main.models import JobTemplate, Schedule +from crum import impersonate + @pytest.fixture def job_template(inventory, project): @@ -18,6 +20,22 @@ def job_template(inventory, project): ) +@pytest.mark.django_db +def test_computed_fields_modified_by_retained(job_template, admin_user): + with impersonate(admin_user): + s = Schedule.objects.create( + name='Some Schedule', + rrule='DTSTART:20300112T210000Z RRULE:FREQ=DAILY;INTERVAL=1', + unified_job_template=job_template + ) + s.refresh_from_db() + assert s.created_by == admin_user + assert s.modified_by == admin_user + s.update_computed_fields() + s.save() + assert s.modified_by == admin_user + + @pytest.mark.django_db def test_repeats_forever(job_template): s = Schedule(