Compare commits

..

1 Commits

Author SHA1 Message Date
Peter Braun
293abc8b35 fix: make indirect host counting live test more reliable 2025-03-04 23:00:50 +01:00
151 changed files with 1074 additions and 5674 deletions

View File

@@ -19,8 +19,6 @@ exclude_also =
branch = True branch = True
omit = omit =
awx/main/migrations/* awx/main/migrations/*
awx/settings/defaults.py
awx/settings/*_defaults.py
source = source =
. .
source_pkgs = source_pkgs =

View File

@@ -4,8 +4,7 @@
<!--- <!---
If you are fixing an existing issue, please include "related #nnn" in your If you are fixing an existing issue, please include "related #nnn" in your
commit message and your description; but you should still explain what commit message and your description; but you should still explain what
the change does. Also please make sure that if this PR has an attached JIRA, put AAP-<number> the change does.
in as the first entry for your PR title.
--> -->
##### ISSUE TYPE ##### ISSUE TYPE
@@ -23,6 +22,11 @@ in as the first entry for your PR title.
- Docs - Docs
- Other - Other
##### AWX VERSION
<!--- Paste verbatim output from `make VERSION` between quotes below -->
```
```
##### ADDITIONAL INFORMATION ##### ADDITIONAL INFORMATION

View File

@@ -11,7 +11,9 @@ inputs:
runs: runs:
using: composite using: composite
steps: steps:
- uses: ./.github/actions/setup-python - name: Get python version from Makefile
shell: bash
run: echo py_version=`make PYTHON_VERSION` >> $GITHUB_ENV
- name: Set lower case owner name - name: Set lower case owner name
shell: bash shell: bash
@@ -24,9 +26,26 @@ runs:
run: | run: |
echo "${{ inputs.github-token }}" | docker login ghcr.io -u ${{ github.actor }} --password-stdin echo "${{ inputs.github-token }}" | docker login ghcr.io -u ${{ github.actor }} --password-stdin
- uses: ./.github/actions/setup-ssh-agent - name: Generate placeholder SSH private key if SSH auth for private repos is not needed
id: generate_key
shell: bash
run: |
if [[ -z "${{ inputs.private-github-key }}" ]]; then
ssh-keygen -t ed25519 -C "github-actions" -N "" -f ~/.ssh/id_ed25519
echo "SSH_PRIVATE_KEY<<EOF" >> $GITHUB_OUTPUT
cat ~/.ssh/id_ed25519 >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
else
echo "SSH_PRIVATE_KEY<<EOF" >> $GITHUB_OUTPUT
echo "${{ inputs.private-github-key }}" >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
fi
- name: Add private GitHub key to SSH agent
uses: webfactory/ssh-agent@v0.9.0
with: with:
ssh-private-key: ${{ inputs.private-github-key }} ssh-private-key: ${{ steps.generate_key.outputs.SSH_PRIVATE_KEY }}
- name: Pre-pull latest devel image to warm cache - name: Pre-pull latest devel image to warm cache
shell: bash shell: bash

View File

@@ -1,27 +0,0 @@
name: 'Setup Python from Makefile'
description: 'Extract and set up Python version from Makefile'
inputs:
python-version:
description: 'Override Python version (optional)'
required: false
default: ''
working-directory:
description: 'Directory containing the Makefile'
required: false
default: '.'
runs:
using: composite
steps:
- name: Get python version from Makefile
shell: bash
run: |
if [ -n "${{ inputs.python-version }}" ]; then
echo "py_version=${{ inputs.python-version }}" >> $GITHUB_ENV
else
cd ${{ inputs.working-directory }}
echo "py_version=`make PYTHON_VERSION`" >> $GITHUB_ENV
fi
- name: Install python
uses: actions/setup-python@v5
with:
python-version: ${{ env.py_version }}

View File

@@ -1,29 +0,0 @@
name: 'Setup SSH for GitHub'
description: 'Configure SSH for private repository access'
inputs:
ssh-private-key:
description: 'SSH private key for repository access'
required: false
default: ''
runs:
using: composite
steps:
- name: Generate placeholder SSH private key if SSH auth for private repos is not needed
id: generate_key
shell: bash
run: |
if [[ -z "${{ inputs.ssh-private-key }}" ]]; then
ssh-keygen -t ed25519 -C "github-actions" -N "" -f ~/.ssh/id_ed25519
echo "SSH_PRIVATE_KEY<<EOF" >> $GITHUB_OUTPUT
cat ~/.ssh/id_ed25519 >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
else
echo "SSH_PRIVATE_KEY<<EOF" >> $GITHUB_OUTPUT
echo "${{ inputs.ssh-private-key }}" >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
fi
- name: Add private GitHub key to SSH agent
uses: webfactory/ssh-agent@v0.9.0
with:
ssh-private-key: ${{ steps.generate_key.outputs.SSH_PRIVATE_KEY }}

View File

@@ -130,7 +130,7 @@ jobs:
with: with:
show-progress: false show-progress: false
- uses: ./.github/actions/setup-python - uses: actions/setup-python@v5
with: with:
python-version: '3.x' python-version: '3.x'
@@ -161,10 +161,6 @@ jobs:
show-progress: false show-progress: false
path: awx path: awx
- uses: ./awx/.github/actions/setup-ssh-agent
with:
ssh-private-key: ${{ secrets.PRIVATE_GITHUB_KEY }}
- name: Checkout awx-operator - name: Checkout awx-operator
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
@@ -172,14 +168,39 @@ jobs:
repository: ansible/awx-operator repository: ansible/awx-operator
path: awx-operator path: awx-operator
- uses: ./awx/.github/actions/setup-python - name: Get python version from Makefile
working-directory: awx
run: echo py_version=`make PYTHON_VERSION` >> $GITHUB_ENV
- name: Install python ${{ env.py_version }}
uses: actions/setup-python@v4
with: with:
working-directory: awx python-version: ${{ env.py_version }}
- name: Install playbook dependencies - name: Install playbook dependencies
run: | run: |
python3 -m pip install docker python3 -m pip install docker
- name: Generate placeholder SSH private key if SSH auth for private repos is not needed
id: generate_key
shell: bash
run: |
if [[ -z "${{ secrets.PRIVATE_GITHUB_KEY }}" ]]; then
ssh-keygen -t ed25519 -C "github-actions" -N "" -f ~/.ssh/id_ed25519
echo "SSH_PRIVATE_KEY<<EOF" >> $GITHUB_OUTPUT
cat ~/.ssh/id_ed25519 >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
else
echo "SSH_PRIVATE_KEY<<EOF" >> $GITHUB_OUTPUT
echo "${{ secrets.PRIVATE_GITHUB_KEY }}" >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
fi
- name: Add private GitHub key to SSH agent
uses: webfactory/ssh-agent@v0.9.0
with:
ssh-private-key: ${{ steps.generate_key.outputs.SSH_PRIVATE_KEY }}
- name: Build AWX image - name: Build AWX image
working-directory: awx working-directory: awx
run: | run: |
@@ -278,7 +299,7 @@ jobs:
with: with:
show-progress: false show-progress: false
- uses: ./.github/actions/setup-python - uses: actions/setup-python@v5
with: with:
python-version: '3.x' python-version: '3.x'
@@ -354,7 +375,7 @@ jobs:
with: with:
show-progress: false show-progress: false
- uses: ./.github/actions/setup-python - uses: actions/setup-python@v5
with: with:
python-version: '3.x' python-version: '3.x'

View File

@@ -49,10 +49,14 @@ jobs:
run: | run: |
echo "DEV_DOCKER_TAG_BASE=ghcr.io/${OWNER,,}" >> $GITHUB_ENV echo "DEV_DOCKER_TAG_BASE=ghcr.io/${OWNER,,}" >> $GITHUB_ENV
echo "COMPOSE_TAG=${GITHUB_REF##*/}" >> $GITHUB_ENV echo "COMPOSE_TAG=${GITHUB_REF##*/}" >> $GITHUB_ENV
echo py_version=`make PYTHON_VERSION` >> $GITHUB_ENV
env: env:
OWNER: '${{ github.repository_owner }}' OWNER: '${{ github.repository_owner }}'
- uses: ./.github/actions/setup-python - name: Install python ${{ env.py_version }}
uses: actions/setup-python@v4
with:
python-version: ${{ env.py_version }}
- name: Log in to registry - name: Log in to registry
run: | run: |
@@ -69,9 +73,25 @@ jobs:
make ui make ui
if: matrix.build-targets.image-name == 'awx' if: matrix.build-targets.image-name == 'awx'
- uses: ./.github/actions/setup-ssh-agent - name: Generate placeholder SSH private key if SSH auth for private repos is not needed
id: generate_key
shell: bash
run: |
if [[ -z "${{ secrets.PRIVATE_GITHUB_KEY }}" ]]; then
ssh-keygen -t ed25519 -C "github-actions" -N "" -f ~/.ssh/id_ed25519
echo "SSH_PRIVATE_KEY<<EOF" >> $GITHUB_OUTPUT
cat ~/.ssh/id_ed25519 >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
else
echo "SSH_PRIVATE_KEY<<EOF" >> $GITHUB_OUTPUT
echo "${{ secrets.PRIVATE_GITHUB_KEY }}" >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
fi
- name: Add private GitHub key to SSH agent
uses: webfactory/ssh-agent@v0.9.0
with: with:
ssh-private-key: ${{ secrets.PRIVATE_GITHUB_KEY }} ssh-private-key: ${{ steps.generate_key.outputs.SSH_PRIVATE_KEY }}
- name: Build and push AWX devel images - name: Build and push AWX devel images
run: | run: |

View File

@@ -12,7 +12,7 @@ jobs:
with: with:
show-progress: false show-progress: false
- uses: ./.github/actions/setup-python - uses: actions/setup-python@v5
with: with:
python-version: '3.x' python-version: '3.x'

View File

@@ -34,11 +34,9 @@ jobs:
with: with:
show-progress: false show-progress: false
- uses: ./.github/actions/setup-python - uses: actions/setup-python@v4
- name: Install python requests - name: Install python requests
run: pip install requests run: pip install requests
- name: Check if user is a member of Ansible org - name: Check if user is a member of Ansible org
uses: jannekem/run-python-script-action@v1 uses: jannekem/run-python-script-action@v1
id: check_user id: check_user

View File

@@ -33,7 +33,7 @@ jobs:
with: with:
show-progress: false show-progress: false
- uses: ./.github/actions/setup-python - uses: actions/setup-python@v5
with: with:
python-version: '3.x' python-version: '3.x'

View File

@@ -36,7 +36,13 @@ jobs:
with: with:
show-progress: false show-progress: false
- uses: ./.github/actions/setup-python - name: Get python version from Makefile
run: echo py_version=`make PYTHON_VERSION` >> $GITHUB_ENV
- name: Install python ${{ env.py_version }}
uses: actions/setup-python@v4
with:
python-version: ${{ env.py_version }}
- name: Install dependencies - name: Install dependencies
run: | run: |

View File

@@ -64,9 +64,14 @@ jobs:
repository: ansible/awx-logos repository: ansible/awx-logos
path: awx-logos path: awx-logos
- uses: ./awx/.github/actions/setup-python - name: Get python version from Makefile
working-directory: awx
run: echo py_version=`make PYTHON_VERSION` >> $GITHUB_ENV
- name: Install python ${{ env.py_version }}
uses: actions/setup-python@v4
with: with:
working-directory: awx python-version: ${{ env.py_version }}
- name: Install playbook dependencies - name: Install playbook dependencies
run: | run: |

View File

@@ -5,7 +5,6 @@ env:
LC_ALL: "C.UTF-8" # prevent ERROR: Ansible could not initialize the preferred locale: unsupported locale setting LC_ALL: "C.UTF-8" # prevent ERROR: Ansible could not initialize the preferred locale: unsupported locale setting
on: on:
workflow_dispatch:
push: push:
branches: branches:
- devel - devel
@@ -23,16 +22,18 @@ jobs:
with: with:
show-progress: false show-progress: false
- uses: ./.github/actions/setup-python - name: Get python version from Makefile
run: echo py_version=`make PYTHON_VERSION` >> $GITHUB_ENV
- name: Install python ${{ env.py_version }}
uses: actions/setup-python@v4
with:
python-version: ${{ env.py_version }}
- name: Log in to registry - name: Log in to registry
run: | run: |
echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u ${{ github.actor }} --password-stdin echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u ${{ github.actor }} --password-stdin
- uses: ./.github/actions/setup-ssh-agent
with:
ssh-private-key: ${{ secrets.PRIVATE_GITHUB_KEY }}
- name: Pre-pull image to warm build cache - name: Pre-pull image to warm build cache
run: | run: |
docker pull -q ghcr.io/${{ github.repository_owner }}/awx_devel:${GITHUB_REF##*/} || : docker pull -q ghcr.io/${{ github.repository_owner }}/awx_devel:${GITHUB_REF##*/} || :
@@ -55,3 +56,5 @@ jobs:
ansible localhost -c local, -m command -a "{{ ansible_python_interpreter + ' -m pip install boto3'}}" ansible localhost -c local, -m command -a "{{ ansible_python_interpreter + ' -m pip install boto3'}}"
ansible localhost -c local -m aws_s3 \ ansible localhost -c local -m aws_s3 \
-a "src=${{ github.workspace }}/schema.json bucket=awx-public-ci-files object=${GITHUB_REF##*/}/schema.json mode=put permission=public-read" -a "src=${{ github.workspace }}/schema.json bucket=awx-public-ci-files object=${GITHUB_REF##*/}/schema.json mode=put permission=public-read"

2
.gitignore vendored
View File

@@ -150,8 +150,6 @@ use_dev_supervisor.txt
awx/ui/src awx/ui/src
awx/ui/build awx/ui/build
awx/ui/.ui-built
awx/ui_next
# Docs build stuff # Docs build stuff
docs/docsite/build/ docs/docsite/build/

View File

@@ -3,17 +3,6 @@
<img src="https://raw.githubusercontent.com/ansible/awx-logos/master/awx/ui/client/assets/logo-login.svg?sanitize=true" width=200 alt="AWX" /> <img src="https://raw.githubusercontent.com/ansible/awx-logos/master/awx/ui/client/assets/logo-login.svg?sanitize=true" width=200 alt="AWX" />
> [!CAUTION]
> The last release of this repository was released on Jul 2, 2024.
> **Releases of this project are now paused during a large scale refactoring.**
> For more information, follow [the Forum](https://forum.ansible.com/) and - more specifically - see the various communications on the matter:
>
> * [Blog: Upcoming Changes to the AWX Project](https://www.ansible.com/blog/upcoming-changes-to-the-awx-project/)
> * [Streamlining AWX Releases](https://forum.ansible.com/t/streamlining-awx-releases/6894) Primary update
> * [Refactoring AWX into a Pluggable, Service-Oriented Architecture](https://forum.ansible.com/t/refactoring-awx-into-a-pluggable-service-oriented-architecture/7404)
> * [Upcoming changes to AWX Operator installation methods](https://forum.ansible.com/t/upcoming-changes-to-awx-operator-installation-methods/7598)
> * [AWX UI and credential types transitioning to the new pluggable architecture](https://forum.ansible.com/t/awx-ui-and-credential-types-transitioning-to-the-new-pluggable-architecture/8027)
AWX provides a web-based user interface, REST API, and task engine built on top of [Ansible](https://github.com/ansible/ansible). It is one of the upstream projects for [Red Hat Ansible Automation Platform](https://www.ansible.com/products/automation-platform). AWX provides a web-based user interface, REST API, and task engine built on top of [Ansible](https://github.com/ansible/ansible). It is one of the upstream projects for [Red Hat Ansible Automation Platform](https://www.ansible.com/products/automation-platform).
To install AWX, please view the [Install guide](./INSTALL.md). To install AWX, please view the [Install guide](./INSTALL.md).

View File

@@ -62,8 +62,7 @@ else:
def prepare_env(): def prepare_env():
# Update the default settings environment variable based on current mode. # Update the default settings environment variable based on current mode.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'awx.settings') os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'awx.settings.%s' % MODE)
os.environ.setdefault('AWX_MODE', MODE)
# Hide DeprecationWarnings when running in production. Need to first load # Hide DeprecationWarnings when running in production. Need to first load
# settings to apply our filter after Django's own warnings filter. # settings to apply our filter after Django's own warnings filter.
from django.conf import settings from django.conf import settings

View File

@@ -161,7 +161,7 @@ def get_view_description(view, html=False):
def get_default_schema(): def get_default_schema():
if settings.DYNACONF.is_development_mode: if settings.SETTINGS_MODULE == 'awx.settings.development':
from awx.api.swagger import schema_view from awx.api.swagger import schema_view
return schema_view return schema_view

View File

@@ -6,8 +6,6 @@ import copy
import json import json
import logging import logging
import re import re
import yaml
import urllib.parse
from collections import Counter, OrderedDict from collections import Counter, OrderedDict
from datetime import timedelta from datetime import timedelta
from uuid import uuid4 from uuid import uuid4
@@ -46,9 +44,6 @@ from ansible_base.lib.utils.models import get_type_for_model
from ansible_base.rbac.models import RoleEvaluation, ObjectRole from ansible_base.rbac.models import RoleEvaluation, ObjectRole
from ansible_base.rbac import permission_registry from ansible_base.rbac import permission_registry
# django-flags
from flags.state import flag_enabled
# AWX # AWX
from awx.main.access import get_user_capabilities from awx.main.access import get_user_capabilities
from awx.main.constants import ACTIVE_STATES, org_role_to_permission from awx.main.constants import ACTIVE_STATES, org_role_to_permission
@@ -120,7 +115,6 @@ from awx.main.utils import (
from awx.main.utils.filters import SmartFilter from awx.main.utils.filters import SmartFilter
from awx.main.utils.plugins import load_combined_inventory_source_options from awx.main.utils.plugins import load_combined_inventory_source_options
from awx.main.utils.named_url_graph import reset_counters from awx.main.utils.named_url_graph import reset_counters
from awx.main.utils.inventory_vars import update_group_variables
from awx.main.scheduler.task_manager_models import TaskManagerModels from awx.main.scheduler.task_manager_models import TaskManagerModels
from awx.main.redact import UriCleaner, REPLACE_STR from awx.main.redact import UriCleaner, REPLACE_STR
from awx.main.signals import update_inventory_computed_fields from awx.main.signals import update_inventory_computed_fields
@@ -632,41 +626,15 @@ class BaseSerializer(serializers.ModelSerializer, metaclass=BaseSerializerMetacl
return exclusions return exclusions
def validate(self, attrs): def validate(self, attrs):
"""
Apply serializer validation. Called by DRF.
Can be extended by subclasses. Or consider overwriting
`validate_with_obj` in subclasses, which provides access to the model
object and exception handling for field validation.
:param dict attrs: The names and values of the model form fields.
:raise rest_framework.exceptions.ValidationError: If the validation
fails.
The exception must contain a dict with the names of the form fields
which failed validation as keys, and a list of error messages as
values. This ensures that the error messages are rendered near the
relevant fields.
:return: The names and values from the model form fields, possibly
modified by the validations.
:rtype: dict
"""
attrs = super(BaseSerializer, self).validate(attrs) attrs = super(BaseSerializer, self).validate(attrs)
# Create/update a model instance and run its full_clean() method to
# do any validation implemented on the model class.
exclusions = self.get_validation_exclusions(self.instance)
# Create a new model instance or take the existing one if it exists,
# and update its attributes with the respective field values from
# attrs.
obj = self.instance or self.Meta.model()
for k, v in attrs.items():
if k not in exclusions and k != 'canonical_address_port':
setattr(obj, k, v)
try: try:
# Run serializer validators which need the model object for # Create/update a model instance and run its full_clean() method to
# validation. # do any validation implemented on the model class.
self.validate_with_obj(attrs, obj) exclusions = self.get_validation_exclusions(self.instance)
# Apply any validations implemented on the model class. obj = self.instance or self.Meta.model()
for k, v in attrs.items():
if k not in exclusions and k != 'canonical_address_port':
setattr(obj, k, v)
obj.full_clean(exclude=exclusions) obj.full_clean(exclude=exclusions)
# full_clean may modify values on the instance; copy those changes # full_clean may modify values on the instance; copy those changes
# back to attrs so they are saved. # back to attrs so they are saved.
@@ -695,32 +663,6 @@ class BaseSerializer(serializers.ModelSerializer, metaclass=BaseSerializerMetacl
raise ValidationError(d) raise ValidationError(d)
return attrs return attrs
def validate_with_obj(self, attrs, obj):
"""
Overwrite this if you need the model instance for your validation.
:param dict attrs: The names and values of the model form fields.
:param obj: An instance of the class's meta model.
If the serializer runs on a newly created object, obj contains only
the attrs from its serializer. If the serializer runs because an
object has been edited, obj is the existing model instance with all
attributes and values available.
:raise django.core.exceptionsValidationError: Raise this if your
validation fails.
To make the error appear at the respective form field, instantiate
the Exception with a dict containing the field name as key and the
error message as value.
Example: ``ValidationError({"password": "Not good enough!"})``
If the exception contains just a string, the message cannot be
related to a field and is rendered at the top of the model form.
:return: None
"""
return
def reverse(self, *args, **kwargs): def reverse(self, *args, **kwargs):
kwargs['request'] = self.context.get('request') kwargs['request'] = self.context.get('request')
return reverse(*args, **kwargs) return reverse(*args, **kwargs)
@@ -737,25 +679,7 @@ class EmptySerializer(serializers.Serializer):
pass pass
class OpaQueryPathEnabledMixin(serializers.Serializer): class UnifiedJobTemplateSerializer(BaseSerializer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not flag_enabled("FEATURE_POLICY_AS_CODE_ENABLED") and 'opa_query_path' in self.fields:
self.fields.pop('opa_query_path')
def validate_opa_query_path(self, value):
# Decode the URL and re-encode it
decoded_value = urllib.parse.unquote(value)
re_encoded_value = urllib.parse.quote(decoded_value, safe='/')
if value != re_encoded_value:
raise serializers.ValidationError(_("The URL must be properly encoded."))
return value
class UnifiedJobTemplateSerializer(BaseSerializer, OpaQueryPathEnabledMixin):
# As a base serializer, the capabilities prefetch is not used directly, # As a base serializer, the capabilities prefetch is not used directly,
# instead they are derived from the Workflow Job Template Serializer and the Job Template Serializer, respectively. # instead they are derived from the Workflow Job Template Serializer and the Job Template Serializer, respectively.
capabilities_prefetch = [] capabilities_prefetch = []
@@ -1060,6 +984,7 @@ class UserSerializer(BaseSerializer):
return ret return ret
def validate_password(self, value): def validate_password(self, value):
django_validate_password(value)
if not self.instance and value in (None, ''): if not self.instance and value in (None, ''):
raise serializers.ValidationError(_('Password required for new User.')) raise serializers.ValidationError(_('Password required for new User.'))
@@ -1082,50 +1007,6 @@ class UserSerializer(BaseSerializer):
return value return value
def validate_with_obj(self, attrs, obj):
"""
Validate the password with the Django password validators
To enable the Django password validators, configure
`settings.AUTH_PASSWORD_VALIDATORS` as described in the [Django
docs](https://docs.djangoproject.com/en/5.1/topics/auth/passwords/#enabling-password-validation)
:param dict attrs: The User form field names and their values as a dict.
Example::
{
'username': 'TestUsername', 'first_name': 'FirstName',
'last_name': 'LastName', 'email': 'First.Last@my.org',
'is_superuser': False, 'is_system_auditor': False,
'password': 'secret123'
}
:param obj: The User model instance.
:raises django.core.exceptions.ValidationError: Raise this if at least
one Django password validator fails.
The exception contains a dict ``{"password": <error-message>``}
which indicates that the password field has failed validation, and
the reason for failure.
:return: None.
"""
# We must do this here instead of in `validate_password` bacause some
# django password validators need access to other model instance fields,
# e.g. ``username`` for the ``UserAttributeSimilarityValidator``.
password = attrs.get("password")
# Skip validation if no password has been entered. This may happen when
# an existing User is edited.
if password and password != '$encrypted$':
# Apply validators from settings.AUTH_PASSWORD_VALIDATORS. This may
# raise ValidationError.
#
# If the validation fails, re-raise the exception with adjusted
# content to make the error appear near the password field.
try:
django_validate_password(password, user=obj)
except DjangoValidationError as exc:
raise DjangoValidationError({"password": exc.messages})
def _update_password(self, obj, new_password): def _update_password(self, obj, new_password):
if new_password and new_password != '$encrypted$': if new_password and new_password != '$encrypted$':
obj.set_password(new_password) obj.set_password(new_password)
@@ -1188,12 +1069,12 @@ class UserActivityStreamSerializer(UserSerializer):
fields = ('*', '-is_system_auditor') fields = ('*', '-is_system_auditor')
class OrganizationSerializer(BaseSerializer, OpaQueryPathEnabledMixin): class OrganizationSerializer(BaseSerializer):
show_capabilities = ['edit', 'delete'] show_capabilities = ['edit', 'delete']
class Meta: class Meta:
model = Organization model = Organization
fields = ('*', 'max_hosts', 'custom_virtualenv', 'default_environment', 'opa_query_path') fields = ('*', 'max_hosts', 'custom_virtualenv', 'default_environment')
read_only_fields = ('*', 'custom_virtualenv') read_only_fields = ('*', 'custom_virtualenv')
def get_related(self, obj): def get_related(self, obj):
@@ -1547,7 +1428,7 @@ class LabelsListMixin(object):
return res return res
class InventorySerializer(LabelsListMixin, BaseSerializerWithVariables, OpaQueryPathEnabledMixin): class InventorySerializer(LabelsListMixin, BaseSerializerWithVariables):
show_capabilities = ['edit', 'delete', 'adhoc', 'copy'] show_capabilities = ['edit', 'delete', 'adhoc', 'copy']
capabilities_prefetch = ['admin', 'adhoc', {'copy': 'organization.inventory_admin'}] capabilities_prefetch = ['admin', 'adhoc', {'copy': 'organization.inventory_admin'}]
@@ -1568,7 +1449,6 @@ class InventorySerializer(LabelsListMixin, BaseSerializerWithVariables, OpaQuery
'inventory_sources_with_failures', 'inventory_sources_with_failures',
'pending_deletion', 'pending_deletion',
'prevent_instance_group_fallback', 'prevent_instance_group_fallback',
'opa_query_path',
) )
def get_related(self, obj): def get_related(self, obj):
@@ -1638,68 +1518,8 @@ class InventorySerializer(LabelsListMixin, BaseSerializerWithVariables, OpaQuery
if kind == 'smart' and not host_filter: if kind == 'smart' and not host_filter:
raise serializers.ValidationError({'host_filter': _('Smart inventories must specify host_filter')}) raise serializers.ValidationError({'host_filter': _('Smart inventories must specify host_filter')})
return super(InventorySerializer, self).validate(attrs) return super(InventorySerializer, self).validate(attrs)
@staticmethod
def _update_variables(variables, inventory_id):
"""
Update the inventory variables of the 'all'-group.
The variables field contains vars from the inventory dialog, hence
representing the "all"-group variables.
Since this is not an update from an inventory source, we update the
variables when the inventory details form is saved.
A user edit on the inventory variables is considered a reset of the
variables update history. Particularly if the user removes a variable by
editing the inventory variables field, the variable is not supposed to
reappear with a value from a previous inventory source update.
We achieve this by forcing `reset=True` on such an update.
As a side-effect, variables which have been set by source updates and
have survived a user-edit (i.e. they have not been deleted from the
variables field) will be assumed to originate from the user edit and are
thus no longer deleted from the inventory when they are removed from
their original source!
Note that we use the inventory source id -1 for user-edit updates
because a regular inventory source cannot have an id of -1 since
PostgreSQL assigns pk's starting from 1 (if this assumption doesn't hold
true, we have to assign another special value for invsrc_id).
:param str variables: The variables as plain text in yaml or json
format.
:param int inventory_id: The primary key of the related inventory
object.
"""
variables_dict = parse_yaml_or_json(variables, silent_failure=False)
logger.debug(f"InventorySerializer._update_variables: {inventory_id=} {variables_dict=}, {variables=}")
update_group_variables(
group_id=None, # `None` denotes the 'all' group (which doesn't have a pk).
newvars=variables_dict,
dbvars=None,
invsrc_id=-1,
inventory_id=inventory_id,
reset=True,
)
def create(self, validated_data):
"""Called when a new inventory has to be created."""
logger.debug(f"InventorySerializer.create({validated_data=}) >>>>")
obj = super().create(validated_data)
self._update_variables(validated_data.get("variables") or "", obj.id)
return obj
def update(self, obj, validated_data):
"""Called when an existing inventory is updated."""
logger.debug(f"InventorySerializer.update({validated_data=}) >>>>")
obj = super().update(obj, validated_data)
self._update_variables(validated_data.get("variables") or "", obj.id)
return obj
class ConstructedFieldMixin(serializers.Field): class ConstructedFieldMixin(serializers.Field):
def get_attribute(self, instance): def get_attribute(self, instance):
@@ -1989,12 +1809,10 @@ class GroupSerializer(BaseSerializerWithVariables):
return res return res
def validate(self, attrs): def validate(self, attrs):
# Do not allow the group name to conflict with an existing host name.
name = force_str(attrs.get('name', self.instance and self.instance.name or '')) name = force_str(attrs.get('name', self.instance and self.instance.name or ''))
inventory = attrs.get('inventory', self.instance and self.instance.inventory or '') inventory = attrs.get('inventory', self.instance and self.instance.inventory or '')
if Host.objects.filter(name=name, inventory=inventory).exists(): if Host.objects.filter(name=name, inventory=inventory).exists():
raise serializers.ValidationError(_('A Host with that name already exists.')) raise serializers.ValidationError(_('A Host with that name already exists.'))
#
return super(GroupSerializer, self).validate(attrs) return super(GroupSerializer, self).validate(attrs)
def validate_name(self, value): def validate_name(self, value):
@@ -3333,7 +3151,6 @@ class JobTemplateSerializer(JobTemplateMixin, UnifiedJobTemplateSerializer, JobO
'webhook_service', 'webhook_service',
'webhook_credential', 'webhook_credential',
'prevent_instance_group_fallback', 'prevent_instance_group_fallback',
'opa_query_path',
) )
read_only_fields = ('*', 'custom_virtualenv') read_only_fields = ('*', 'custom_virtualenv')
@@ -3535,17 +3352,11 @@ class JobRelaunchSerializer(BaseSerializer):
choices=[('all', _('No change to job limit')), ('failed', _('All failed and unreachable hosts'))], choices=[('all', _('No change to job limit')), ('failed', _('All failed and unreachable hosts'))],
write_only=True, write_only=True,
) )
job_type = serializers.ChoiceField(
required=False,
allow_null=True,
choices=NEW_JOB_TYPE_CHOICES,
write_only=True,
)
credential_passwords = VerbatimField(required=True, write_only=True) credential_passwords = VerbatimField(required=True, write_only=True)
class Meta: class Meta:
model = Job model = Job
fields = ('passwords_needed_to_start', 'retry_counts', 'hosts', 'job_type', 'credential_passwords') fields = ('passwords_needed_to_start', 'retry_counts', 'hosts', 'credential_passwords')
def validate_credential_passwords(self, value): def validate_credential_passwords(self, value):
pnts = self.instance.passwords_needed_to_start pnts = self.instance.passwords_needed_to_start
@@ -6004,34 +5815,6 @@ class InstanceGroupSerializer(BaseSerializer):
raise serializers.ValidationError(_('Only Kubernetes credentials can be associated with an Instance Group')) raise serializers.ValidationError(_('Only Kubernetes credentials can be associated with an Instance Group'))
return value return value
def validate_pod_spec_override(self, value):
if not value:
return value
# value should be empty for non-container groups
if self.instance and not self.instance.is_container_group:
raise serializers.ValidationError(_('pod_spec_override is only valid for container groups'))
pod_spec_override_json = None
# defect if the value is yaml or json if yaml convert to json
try:
# convert yaml to json
pod_spec_override_json = yaml.safe_load(value)
except yaml.YAMLError:
try:
pod_spec_override_json = json.loads(value)
except json.JSONDecodeError:
raise serializers.ValidationError(_('pod_spec_override must be valid yaml or json'))
# validate the
spec = pod_spec_override_json.get('spec', {})
automount_service_account_token = spec.get('automountServiceAccountToken', False)
if automount_service_account_token:
raise serializers.ValidationError(_('automountServiceAccountToken is not allowed for security reasons'))
return value
def validate(self, attrs): def validate(self, attrs):
attrs = super(InstanceGroupSerializer, self).validate(attrs) attrs = super(InstanceGroupSerializer, self).validate(attrs)

View File

@@ -3435,7 +3435,6 @@ class JobRelaunch(RetrieveAPIView):
copy_kwargs = {} copy_kwargs = {}
retry_hosts = serializer.validated_data.get('hosts', None) retry_hosts = serializer.validated_data.get('hosts', None)
job_type = serializer.validated_data.get('job_type', None)
if retry_hosts and retry_hosts != 'all': if retry_hosts and retry_hosts != 'all':
if obj.status in ACTIVE_STATES: if obj.status in ACTIVE_STATES:
return Response( return Response(
@@ -3456,8 +3455,6 @@ class JobRelaunch(RetrieveAPIView):
) )
copy_kwargs['limit'] = ','.join(retry_host_list) copy_kwargs['limit'] = ','.join(retry_host_list)
if job_type:
copy_kwargs['job_type'] = job_type
new_job = obj.copy_unified_job(**copy_kwargs) new_job = obj.copy_unified_job(**copy_kwargs)
result = new_job.signal_start(**serializer.validated_data['credential_passwords']) result = new_job.signal_start(**serializer.validated_data['credential_passwords'])
if not result: if not result:

View File

@@ -10,7 +10,7 @@ from awx.api.generics import APIView, Response
from awx.api.permissions import AnalyticsPermission from awx.api.permissions import AnalyticsPermission
from awx.api.versioning import reverse from awx.api.versioning import reverse
from awx.main.utils import get_awx_version from awx.main.utils import get_awx_version
from awx.main.utils.analytics_proxy import OIDCClient from awx.main.utils.analytics_proxy import OIDCClient, DEFAULT_OIDC_ENDPOINT
from rest_framework import status from rest_framework import status
from collections import OrderedDict from collections import OrderedDict
@@ -202,16 +202,10 @@ class AnalyticsGenericView(APIView):
if method not in ["GET", "POST", "OPTIONS"]: if method not in ["GET", "POST", "OPTIONS"]:
return self._error_response(ERROR_UNSUPPORTED_METHOD, method, remote=False, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) return self._error_response(ERROR_UNSUPPORTED_METHOD, method, remote=False, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
url = self._get_analytics_url(request.path) url = self._get_analytics_url(request.path)
using_subscriptions_credentials = False
try: try:
rh_user = getattr(settings, 'REDHAT_USERNAME', None) rh_user = self._get_setting('REDHAT_USERNAME', None, ERROR_MISSING_USER)
rh_password = getattr(settings, 'REDHAT_PASSWORD', None) rh_password = self._get_setting('REDHAT_PASSWORD', None, ERROR_MISSING_PASSWORD)
if not (rh_user and rh_password): client = OIDCClient(rh_user, rh_password, DEFAULT_OIDC_ENDPOINT, ['api.console'])
rh_user = self._get_setting('SUBSCRIPTIONS_CLIENT_ID', None, ERROR_MISSING_USER)
rh_password = self._get_setting('SUBSCRIPTIONS_CLIENT_SECRET', None, ERROR_MISSING_PASSWORD)
using_subscriptions_credentials = True
client = OIDCClient(rh_user, rh_password)
response = client.make_request( response = client.make_request(
method, method,
url, url,
@@ -222,17 +216,17 @@ class AnalyticsGenericView(APIView):
timeout=(31, 31), timeout=(31, 31),
) )
except requests.RequestException: except requests.RequestException:
# subscriptions credentials are not valid for basic auth, so just return 401 logger.error("Automation Analytics API request failed, trying base auth method")
if using_subscriptions_credentials: response = self._base_auth_request(request, method, url, rh_user, rh_password, headers)
response = Response(status=status.HTTP_401_UNAUTHORIZED) except MissingSettings:
else: rh_user = self._get_setting('SUBSCRIPTIONS_USERNAME', None, ERROR_MISSING_USER)
logger.error("Automation Analytics API request failed, trying base auth method") rh_password = self._get_setting('SUBSCRIPTIONS_PASSWORD', None, ERROR_MISSING_PASSWORD)
response = self._base_auth_request(request, method, url, rh_user, rh_password, headers) response = self._base_auth_request(request, method, url, rh_user, rh_password, headers)
# #
# Missing or wrong user/pass # Missing or wrong user/pass
# #
if response.status_code == status.HTTP_401_UNAUTHORIZED: if response.status_code == status.HTTP_401_UNAUTHORIZED:
text = response.get('text', '').rstrip("\n") text = (response.text or '').rstrip("\n")
return self._error_response(ERROR_UNAUTHORIZED, text, remote=True, remote_status_code=response.status_code) return self._error_response(ERROR_UNAUTHORIZED, text, remote=True, remote_status_code=response.status_code)
# #
# Not found, No entitlement or No data in Analytics # Not found, No entitlement or No data in Analytics

View File

@@ -32,7 +32,6 @@ from awx.api.versioning import URLPathVersioning, reverse, drf_reverse
from awx.main.constants import PRIVILEGE_ESCALATION_METHODS from awx.main.constants import PRIVILEGE_ESCALATION_METHODS
from awx.main.models import Project, Organization, Instance, InstanceGroup, JobTemplate from awx.main.models import Project, Organization, Instance, InstanceGroup, JobTemplate
from awx.main.utils import set_environ from awx.main.utils import set_environ
from awx.main.utils.analytics_proxy import TokenError
from awx.main.utils.licensing import get_licenser from awx.main.utils.licensing import get_licenser
logger = logging.getLogger('awx.api.views.root') logger = logging.getLogger('awx.api.views.root')
@@ -177,21 +176,19 @@ class ApiV2SubscriptionView(APIView):
def post(self, request): def post(self, request):
data = request.data.copy() data = request.data.copy()
if data.get('subscriptions_client_secret') == '$encrypted$': if data.get('subscriptions_password') == '$encrypted$':
data['subscriptions_client_secret'] = settings.SUBSCRIPTIONS_CLIENT_SECRET data['subscriptions_password'] = settings.SUBSCRIPTIONS_PASSWORD
try: try:
user, pw = data.get('subscriptions_client_id'), data.get('subscriptions_client_secret') user, pw = data.get('subscriptions_username'), data.get('subscriptions_password')
with set_environ(**settings.AWX_TASK_ENV): with set_environ(**settings.AWX_TASK_ENV):
validated = get_licenser().validate_rh(user, pw) validated = get_licenser().validate_rh(user, pw)
if user: if user:
settings.SUBSCRIPTIONS_CLIENT_ID = data['subscriptions_client_id'] settings.SUBSCRIPTIONS_USERNAME = data['subscriptions_username']
if pw: if pw:
settings.SUBSCRIPTIONS_CLIENT_SECRET = data['subscriptions_client_secret'] settings.SUBSCRIPTIONS_PASSWORD = data['subscriptions_password']
except Exception as exc: except Exception as exc:
msg = _("Invalid Subscription") msg = _("Invalid Subscription")
if isinstance(exc, TokenError) or ( if isinstance(exc, requests.exceptions.HTTPError) and getattr(getattr(exc, 'response', None), 'status_code', None) == 401:
isinstance(exc, requests.exceptions.HTTPError) and getattr(getattr(exc, 'response', None), 'status_code', None) == 401
):
msg = _("The provided credentials are invalid (HTTP 401).") msg = _("The provided credentials are invalid (HTTP 401).")
elif isinstance(exc, requests.exceptions.ProxyError): elif isinstance(exc, requests.exceptions.ProxyError):
msg = _("Unable to connect to proxy server.") msg = _("Unable to connect to proxy server.")
@@ -218,12 +215,12 @@ class ApiV2AttachView(APIView):
def post(self, request): def post(self, request):
data = request.data.copy() data = request.data.copy()
subscription_id = data.get('subscription_id', None) pool_id = data.get('pool_id', None)
if not subscription_id: if not pool_id:
return Response({"error": _("No subscription ID provided.")}, status=status.HTTP_400_BAD_REQUEST) return Response({"error": _("No subscription pool ID provided.")}, status=status.HTTP_400_BAD_REQUEST)
user = getattr(settings, 'SUBSCRIPTIONS_CLIENT_ID', None) user = getattr(settings, 'SUBSCRIPTIONS_USERNAME', None)
pw = getattr(settings, 'SUBSCRIPTIONS_CLIENT_SECRET', None) pw = getattr(settings, 'SUBSCRIPTIONS_PASSWORD', None)
if subscription_id and user and pw: if pool_id and user and pw:
data = request.data.copy() data = request.data.copy()
try: try:
with set_environ(**settings.AWX_TASK_ENV): with set_environ(**settings.AWX_TASK_ENV):
@@ -242,7 +239,7 @@ class ApiV2AttachView(APIView):
logger.exception(smart_str(u"Invalid subscription submitted."), extra=dict(actor=request.user.username)) logger.exception(smart_str(u"Invalid subscription submitted."), extra=dict(actor=request.user.username))
return Response({"error": msg}, status=status.HTTP_400_BAD_REQUEST) return Response({"error": msg}, status=status.HTTP_400_BAD_REQUEST)
for sub in validated: for sub in validated:
if sub['subscription_id'] == subscription_id: if sub['pool_id'] == pool_id:
sub['valid_key'] = True sub['valid_key'] = True
settings.LICENSE = sub settings.LICENSE = sub
return Response(sub) return Response(sub)

View File

@@ -10,7 +10,7 @@ from django.core.validators import URLValidator, _lazy_re_compile
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
# Django REST Framework # Django REST Framework
from rest_framework.fields import BooleanField, CharField, ChoiceField, DictField, DateTimeField, EmailField, IntegerField, ListField, FloatField # noqa from rest_framework.fields import BooleanField, CharField, ChoiceField, DictField, DateTimeField, EmailField, IntegerField, ListField # noqa
from rest_framework.serializers import PrimaryKeyRelatedField # noqa from rest_framework.serializers import PrimaryKeyRelatedField # noqa
# AWX # AWX
@@ -207,8 +207,7 @@ class URLField(CharField):
if self.allow_plain_hostname: if self.allow_plain_hostname:
try: try:
url_parts = urlparse.urlsplit(value) url_parts = urlparse.urlsplit(value)
looks_like_ipv6 = bool(url_parts.netloc and url_parts.netloc.startswith('[') and url_parts.netloc.endswith(']')) if url_parts.hostname and '.' not in url_parts.hostname:
if not looks_like_ipv6 and url_parts.hostname and '.' not in url_parts.hostname:
netloc = '{}.local'.format(url_parts.hostname) netloc = '{}.local'.format(url_parts.hostname)
if url_parts.port: if url_parts.port:
netloc = '{}:{}'.format(netloc, url_parts.port) netloc = '{}:{}'.format(netloc, url_parts.port)

View File

@@ -27,5 +27,5 @@ def _migrate_setting(apps, old_key, new_key, encrypted=False):
def prefill_rh_credentials(apps, schema_editor): def prefill_rh_credentials(apps, schema_editor):
_migrate_setting(apps, 'REDHAT_USERNAME', 'SUBSCRIPTIONS_CLIENT_ID', encrypted=False) _migrate_setting(apps, 'REDHAT_USERNAME', 'SUBSCRIPTIONS_USERNAME', encrypted=False)
_migrate_setting(apps, 'REDHAT_PASSWORD', 'SUBSCRIPTIONS_CLIENT_SECRET', encrypted=True) _migrate_setting(apps, 'REDHAT_PASSWORD', 'SUBSCRIPTIONS_PASSWORD', encrypted=True)

View File

@@ -128,41 +128,3 @@ class TestURLField:
else: else:
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
field.run_validators(url) field.run_validators(url)
@pytest.mark.parametrize(
"url, expect_error",
[
("https://[1:2:3]", True),
("http://[1:2:3]", True),
("https://[2001:db8:3333:4444:5555:6666:7777:8888", True),
("https://2001:db8:3333:4444:5555:6666:7777:8888", True),
("https://[2001:db8:3333:4444:5555:6666:7777:8888]", False),
("https://[::1]", False),
("https://[::]", False),
("https://[2001:db8::1]", False),
("https://[2001:db8:0:0:0:0:1:1]", False),
("https://[fe80::2%eth0]", True), # ipv6 scope identifier
("https://[fe80:0:0:0:200:f8ff:fe21:67cf]", False),
("https://[::ffff:192.168.1.10]", False),
("https://[0:0:0:0:0:ffff:c000:0201]", False),
("https://[2001:0db8:000a:0001:0000:0000:0000:0000]", False),
("https://[2001:db8:a:1::]", False),
("https://[ff02::1]", False),
("https://[ff02:0:0:0:0:0:0:1]", False),
("https://[fc00::1]", False),
("https://[fd12:3456:789a:1::1]", False),
("https://[2001:db8::abcd:ef12:3456:7890]", False),
("https://[2001:db8:0000:abcd:0000:ef12:0000:3456]", False),
("https://[::ffff:10.0.0.1]", False),
("https://[2001:db8:cafe::]", False),
("https://[2001:db8:cafe:0:0:0:0:0]", False),
("https://[fe80::210:f3ff:fedf:4567%3]", True), # ipv6 scope identifier, numerical interface
],
)
def test_ipv6_urls(self, url, expect_error):
field = URLField()
if expect_error:
with pytest.raises(ValidationError, match="Enter a valid URL"):
field.run_validators(url)
else:
field.run_validators(url)

View File

@@ -2098,7 +2098,7 @@ class WorkflowJobAccess(BaseAccess):
def filtered_queryset(self): def filtered_queryset(self):
return WorkflowJob.objects.filter( return WorkflowJob.objects.filter(
Q(unified_job_template__in=UnifiedJobTemplate.accessible_pk_qs(self.user, 'read_role')) Q(unified_job_template__in=UnifiedJobTemplate.accessible_pk_qs(self.user, 'read_role'))
| Q(organization__in=Organization.accessible_pk_qs(self.user, 'auditor_role')) | Q(organization__in=Organization.objects.filter(Q(admin_role__members=self.user)), is_bulk_job=True)
) )
def can_read(self, obj): def can_read(self, obj):
@@ -2496,11 +2496,12 @@ class UnifiedJobAccess(BaseAccess):
def filtered_queryset(self): def filtered_queryset(self):
inv_pk_qs = Inventory._accessible_pk_qs(Inventory, self.user, 'read_role') inv_pk_qs = Inventory._accessible_pk_qs(Inventory, self.user, 'read_role')
org_auditor_qs = Organization.objects.filter(Q(admin_role__members=self.user) | Q(auditor_role__members=self.user))
qs = self.model.objects.filter( qs = self.model.objects.filter(
Q(unified_job_template_id__in=UnifiedJobTemplate.accessible_pk_qs(self.user, 'read_role')) Q(unified_job_template_id__in=UnifiedJobTemplate.accessible_pk_qs(self.user, 'read_role'))
| Q(inventoryupdate__inventory_source__inventory__id__in=inv_pk_qs) | Q(inventoryupdate__inventory_source__inventory__id__in=inv_pk_qs)
| Q(adhoccommand__inventory__id__in=inv_pk_qs) | Q(adhoccommand__inventory__id__in=inv_pk_qs)
| Q(organization__in=Organization.accessible_pk_qs(self.user, 'auditor_role')) | Q(organization__in=org_auditor_qs)
) )
return qs return qs

View File

@@ -22,7 +22,7 @@ from ansible_base.lib.utils.db import advisory_lock
from awx.main.models import Job from awx.main.models import Job
from awx.main.access import access_registry from awx.main.access import access_registry
from awx.main.utils import get_awx_http_client_headers, set_environ, datetime_hook from awx.main.utils import get_awx_http_client_headers, set_environ, datetime_hook
from awx.main.utils.analytics_proxy import OIDCClient from awx.main.utils.analytics_proxy import OIDCClient, DEFAULT_OIDC_ENDPOINT
__all__ = ['register', 'gather', 'ship'] __all__ = ['register', 'gather', 'ship']
@@ -186,7 +186,7 @@ def gather(dest=None, module=None, subset=None, since=None, until=None, collecti
if not ( if not (
settings.AUTOMATION_ANALYTICS_URL settings.AUTOMATION_ANALYTICS_URL
and ((settings.REDHAT_USERNAME and settings.REDHAT_PASSWORD) or (settings.SUBSCRIPTIONS_CLIENT_ID and settings.SUBSCRIPTIONS_CLIENT_SECRET)) and ((settings.REDHAT_USERNAME and settings.REDHAT_PASSWORD) or (settings.SUBSCRIPTIONS_USERNAME and settings.SUBSCRIPTIONS_PASSWORD))
): ):
logger.log(log_level, "Not gathering analytics, configuration is invalid. Use --dry-run to gather locally without sending.") logger.log(log_level, "Not gathering analytics, configuration is invalid. Use --dry-run to gather locally without sending.")
return None return None
@@ -368,20 +368,8 @@ def ship(path):
logger.error('AUTOMATION_ANALYTICS_URL is not set') logger.error('AUTOMATION_ANALYTICS_URL is not set')
return False return False
rh_id = getattr(settings, 'REDHAT_USERNAME', None) rh_user = getattr(settings, 'REDHAT_USERNAME', None)
rh_secret = getattr(settings, 'REDHAT_PASSWORD', None) rh_password = getattr(settings, 'REDHAT_PASSWORD', None)
if not (rh_id and rh_secret):
rh_id = getattr(settings, 'SUBSCRIPTIONS_CLIENT_ID', None)
rh_secret = getattr(settings, 'SUBSCRIPTIONS_CLIENT_SECRET', None)
if not rh_id:
logger.error('Neither REDHAT_USERNAME nor SUBSCRIPTIONS_CLIENT_ID are set')
return False
if not rh_secret:
logger.error('Neither REDHAT_PASSWORD nor SUBSCRIPTIONS_CLIENT_SECRET are set')
return False
with open(path, 'rb') as f: with open(path, 'rb') as f:
files = {'file': (os.path.basename(path), f, settings.INSIGHTS_AGENT_MIME)} files = {'file': (os.path.basename(path), f, settings.INSIGHTS_AGENT_MIME)}
@@ -389,13 +377,25 @@ def ship(path):
s.headers = get_awx_http_client_headers() s.headers = get_awx_http_client_headers()
s.headers.pop('Content-Type') s.headers.pop('Content-Type')
with set_environ(**settings.AWX_TASK_ENV): with set_environ(**settings.AWX_TASK_ENV):
try: if rh_user and rh_password:
client = OIDCClient(rh_id, rh_secret) try:
response = client.make_request("POST", url, headers=s.headers, files=files, verify=settings.INSIGHTS_CERT_PATH, timeout=(31, 31)) client = OIDCClient(rh_user, rh_password, DEFAULT_OIDC_ENDPOINT, ['api.console'])
except requests.RequestException: response = client.make_request("POST", url, headers=s.headers, files=files, verify=settings.INSIGHTS_CERT_PATH, timeout=(31, 31))
logger.error("Automation Analytics API request failed, trying base auth method") except requests.RequestException:
response = s.post(url, files=files, verify=settings.INSIGHTS_CERT_PATH, auth=(rh_id, rh_secret), headers=s.headers, timeout=(31, 31)) logger.error("Automation Analytics API request failed, trying base auth method")
response = s.post(url, files=files, verify=settings.INSIGHTS_CERT_PATH, auth=(rh_user, rh_password), headers=s.headers, timeout=(31, 31))
elif not rh_user or not rh_password:
logger.info('REDHAT_USERNAME and REDHAT_PASSWORD are not set, using SUBSCRIPTIONS_USERNAME and SUBSCRIPTIONS_PASSWORD')
rh_user = getattr(settings, 'SUBSCRIPTIONS_USERNAME', None)
rh_password = getattr(settings, 'SUBSCRIPTIONS_PASSWORD', None)
if rh_user and rh_password:
response = s.post(url, files=files, verify=settings.INSIGHTS_CERT_PATH, auth=(rh_user, rh_password), headers=s.headers, timeout=(31, 31))
elif not rh_user:
logger.error('REDHAT_USERNAME and SUBSCRIPTIONS_USERNAME are not set')
return False
elif not rh_password:
logger.error('REDHAT_PASSWORD and SUBSCRIPTIONS_USERNAME are not set')
return False
# Accept 2XX status_codes # Accept 2XX status_codes
if response.status_code >= 300: if response.status_code >= 300:
logger.error('Upload failed with status {}, {}'.format(response.status_code, response.text)) logger.error('Upload failed with status {}, {}'.format(response.status_code, response.text))

View File

@@ -9,7 +9,6 @@ from prometheus_client.core import GaugeMetricFamily, HistogramMetricFamily
from prometheus_client.registry import CollectorRegistry from prometheus_client.registry import CollectorRegistry
from django.conf import settings from django.conf import settings
from django.http import HttpRequest from django.http import HttpRequest
import redis.exceptions
from rest_framework.request import Request from rest_framework.request import Request
from awx.main.consumers import emit_channel_notification from awx.main.consumers import emit_channel_notification
@@ -291,12 +290,8 @@ class Metrics(MetricsNamespace):
def send_metrics(self): def send_metrics(self):
# more than one thread could be calling this at the same time, so should # more than one thread could be calling this at the same time, so should
# acquire redis lock before sending metrics # acquire redis lock before sending metrics
try: lock = self.conn.lock(root_key + '-' + self._namespace + '_lock')
lock = self.conn.lock(root_key + '-' + self._namespace + '_lock') if not lock.acquire(blocking=False):
if not lock.acquire(blocking=False):
return
except redis.exceptions.ConnectionError as exc:
logger.warning(f'Connection error in send_metrics: {exc}')
return return
try: try:
current_time = time.time() current_time = time.time()

View File

@@ -4,7 +4,6 @@ import logging
# Django # Django
from django.core.checks import Error from django.core.checks import Error
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django.conf import settings
# Django REST Framework # Django REST Framework
from rest_framework import serializers from rest_framework import serializers
@@ -13,7 +12,6 @@ from rest_framework import serializers
from awx.conf import fields, register, register_validate from awx.conf import fields, register, register_validate
from awx.main.models import ExecutionEnvironment from awx.main.models import ExecutionEnvironment
from awx.main.constants import SUBSCRIPTION_USAGE_MODEL_UNIQUE_HOSTS from awx.main.constants import SUBSCRIPTION_USAGE_MODEL_UNIQUE_HOSTS
from awx.main.tasks.policy import OPA_AUTH_TYPES
logger = logging.getLogger('awx.main.conf') logger = logging.getLogger('awx.main.conf')
@@ -126,8 +124,8 @@ register(
allow_blank=True, allow_blank=True,
encrypted=False, encrypted=False,
read_only=False, read_only=False,
label=_('Red Hat Client ID for Analytics'), label=_('Red Hat customer username'),
help_text=_('Client ID used to send data to Automation Analytics'), help_text=_('This username is used to send data to Automation Analytics'),
category=_('System'), category=_('System'),
category_slug='system', category_slug='system',
) )
@@ -139,34 +137,34 @@ register(
allow_blank=True, allow_blank=True,
encrypted=True, encrypted=True,
read_only=False, read_only=False,
label=_('Red Hat Client Secret for Analytics'), label=_('Red Hat customer password'),
help_text=_('Client secret used to send data to Automation Analytics'), help_text=_('This password is used to send data to Automation Analytics'),
category=_('System'), category=_('System'),
category_slug='system', category_slug='system',
) )
register( register(
'SUBSCRIPTIONS_CLIENT_ID', 'SUBSCRIPTIONS_USERNAME',
field_class=fields.CharField, field_class=fields.CharField,
default='', default='',
allow_blank=True, allow_blank=True,
encrypted=False, encrypted=False,
read_only=False, read_only=False,
label=_('Red Hat Client ID for Subscriptions'), label=_('Red Hat or Satellite username'),
help_text=_('Client ID used to retrieve subscription and content information'), # noqa help_text=_('This username is used to retrieve subscription and content information'), # noqa
category=_('System'), category=_('System'),
category_slug='system', category_slug='system',
) )
register( register(
'SUBSCRIPTIONS_CLIENT_SECRET', 'SUBSCRIPTIONS_PASSWORD',
field_class=fields.CharField, field_class=fields.CharField,
default='', default='',
allow_blank=True, allow_blank=True,
encrypted=True, encrypted=True,
read_only=False, read_only=False,
label=_('Red Hat Client Secret for Subscriptions'), label=_('Red Hat or Satellite password'),
help_text=_('Client secret used to retrieve subscription and content information'), # noqa help_text=_('This password is used to retrieve subscription and content information'), # noqa
category=_('System'), category=_('System'),
category_slug='system', category_slug='system',
) )
@@ -982,125 +980,3 @@ def csrf_trusted_origins_validate(serializer, attrs):
register_validate('system', csrf_trusted_origins_validate) register_validate('system', csrf_trusted_origins_validate)
if settings.FEATURE_POLICY_AS_CODE_ENABLED: # Unable to use flag_enabled due to AppRegistryNotReady error
register(
'OPA_HOST',
field_class=fields.CharField,
label=_('OPA server hostname'),
default='',
help_text=_('The hostname used to connect to the OPA server. If empty, policy enforcement will be disabled.'),
category=('PolicyAsCode'),
category_slug='policyascode',
allow_blank=True,
)
register(
'OPA_PORT',
field_class=fields.IntegerField,
label=_('OPA server port'),
default=8181,
help_text=_('The port used to connect to the OPA server. Defaults to 8181.'),
category=('PolicyAsCode'),
category_slug='policyascode',
)
register(
'OPA_SSL',
field_class=fields.BooleanField,
label=_('Use SSL for OPA connection'),
default=False,
help_text=_('Enable or disable the use of SSL to connect to the OPA server. Defaults to false.'),
category=('PolicyAsCode'),
category_slug='policyascode',
)
register(
'OPA_AUTH_TYPE',
field_class=fields.ChoiceField,
label=_('OPA authentication type'),
choices=[OPA_AUTH_TYPES.NONE, OPA_AUTH_TYPES.TOKEN, OPA_AUTH_TYPES.CERTIFICATE],
default=OPA_AUTH_TYPES.NONE,
help_text=_('The authentication type that will be used to connect to the OPA server: "None", "Token", or "Certificate".'),
category=('PolicyAsCode'),
category_slug='policyascode',
)
register(
'OPA_AUTH_TOKEN',
field_class=fields.CharField,
label=_('OPA authentication token'),
default='',
help_text=_(
'The token for authentication to the OPA server. Required when OPA_AUTH_TYPE is "Token". If an authorization header is defined in OPA_AUTH_CUSTOM_HEADERS, it will be overridden by OPA_AUTH_TOKEN.'
),
category=('PolicyAsCode'),
category_slug='policyascode',
allow_blank=True,
encrypted=True,
)
register(
'OPA_AUTH_CLIENT_CERT',
field_class=fields.CharField,
label=_('OPA client certificate content'),
default='',
help_text=_('The content of the client certificate file for mTLS authentication to the OPA server. Required when OPA_AUTH_TYPE is "Certificate".'),
category=('PolicyAsCode'),
category_slug='policyascode',
allow_blank=True,
)
register(
'OPA_AUTH_CLIENT_KEY',
field_class=fields.CharField,
label=_('OPA client key content'),
default='',
help_text=_('The content of the client key for mTLS authentication to the OPA server. Required when OPA_AUTH_TYPE is "Certificate".'),
category=('PolicyAsCode'),
category_slug='policyascode',
allow_blank=True,
encrypted=True,
)
register(
'OPA_AUTH_CA_CERT',
field_class=fields.CharField,
label=_('OPA CA certificate content'),
default='',
help_text=_('The content of the CA certificate for mTLS authentication to the OPA server. Required when OPA_AUTH_TYPE is "Certificate".'),
category=('PolicyAsCode'),
category_slug='policyascode',
allow_blank=True,
)
register(
'OPA_AUTH_CUSTOM_HEADERS',
field_class=fields.DictField,
label=_('OPA custom authentication headers'),
default={},
help_text=_('Optional custom headers included in requests to the OPA server. Defaults to empty dictionary ({}).'),
category=('PolicyAsCode'),
category_slug='policyascode',
)
register(
'OPA_REQUEST_TIMEOUT',
field_class=fields.FloatField,
label=_('OPA request timeout'),
default=1.5,
help_text=_('The number of seconds after which the connection to the OPA server will time out. Defaults to 1.5 seconds.'),
category=('PolicyAsCode'),
category_slug='policyascode',
)
register(
'OPA_REQUEST_RETRIES',
field_class=fields.IntegerField,
label=_('OPA request retry count'),
default=2,
help_text=_('The number of retry attempts for connecting to the OPA server. Default is 2.'),
category=('PolicyAsCode'),
category_slug='policyascode',
)

View File

@@ -88,10 +88,8 @@ class Scheduler:
# internally times are all referenced relative to startup time, add grace period # internally times are all referenced relative to startup time, add grace period
self.global_start = time.time() + 2.0 self.global_start = time.time() + 2.0
def get_and_mark_pending(self, reftime=None): def get_and_mark_pending(self):
if reftime is None: relative_time = time.time() - self.global_start
reftime = time.time() # mostly for tests
relative_time = reftime - self.global_start
to_run = [] to_run = []
for job in self.jobs: for job in self.jobs:
if job.due_to_run(relative_time): if job.due_to_run(relative_time):
@@ -100,10 +98,8 @@ class Scheduler:
job.mark_run(relative_time) job.mark_run(relative_time)
return to_run return to_run
def time_until_next_run(self, reftime=None): def time_until_next_run(self):
if reftime is None: relative_time = time.time() - self.global_start
reftime = time.time() # mostly for tests
relative_time = reftime - self.global_start
next_job = min(self.jobs, key=lambda j: j.next_run) next_job = min(self.jobs, key=lambda j: j.next_run)
delta = next_job.next_run - relative_time delta = next_job.next_run - relative_time
if delta <= 0.1: if delta <= 0.1:
@@ -119,11 +115,10 @@ class Scheduler:
def debug(self, *args, **kwargs): def debug(self, *args, **kwargs):
data = dict() data = dict()
data['title'] = 'Scheduler status' data['title'] = 'Scheduler status'
reftime = time.time()
now = datetime.fromtimestamp(reftime).strftime('%Y-%m-%d %H:%M:%S UTC') now = datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S UTC')
start_time = datetime.fromtimestamp(self.global_start).strftime('%Y-%m-%d %H:%M:%S UTC') start_time = datetime.fromtimestamp(self.global_start).strftime('%Y-%m-%d %H:%M:%S UTC')
relative_time = reftime - self.global_start relative_time = time.time() - self.global_start
data['started_time'] = start_time data['started_time'] = start_time
data['current_time'] = now data['current_time'] = now
data['current_time_relative'] = round(relative_time, 3) data['current_time_relative'] = round(relative_time, 3)

View File

@@ -7,7 +7,6 @@ import time
import traceback import traceback
from datetime import datetime from datetime import datetime
from uuid import uuid4 from uuid import uuid4
import json
import collections import collections
from multiprocessing import Process from multiprocessing import Process
@@ -26,10 +25,7 @@ from ansible_base.lib.logging.runtime import log_excess_runtime
from awx.main.models import UnifiedJob from awx.main.models import UnifiedJob
from awx.main.dispatch import reaper from awx.main.dispatch import reaper
from awx.main.utils.common import get_mem_effective_capacity, get_corrected_memory, get_corrected_cpu, get_cpu_effective_capacity from awx.main.utils.common import convert_mem_str_to_bytes, get_mem_effective_capacity
# ansible-runner
from ansible_runner.utils.capacity import get_mem_in_bytes, get_cpu_count
if 'run_callback_receiver' in sys.argv: if 'run_callback_receiver' in sys.argv:
logger = logging.getLogger('awx.main.commands.run_callback_receiver') logger = logging.getLogger('awx.main.commands.run_callback_receiver')
@@ -311,41 +307,6 @@ class WorkerPool(object):
logger.exception('could not kill {}'.format(worker.pid)) logger.exception('could not kill {}'.format(worker.pid))
def get_auto_max_workers():
"""Method we normally rely on to get max_workers
Uses almost same logic as Instance.local_health_check
The important thing is to be MORE than Instance.capacity
so that the task-manager does not over-schedule this node
Ideally we would just use the capacity from the database plus reserve workers,
but this poses some bootstrap problems where OCP task containers
register themselves after startup
"""
# Get memory from ansible-runner
total_memory_gb = get_mem_in_bytes()
# This may replace memory calculation with a user override
corrected_memory = get_corrected_memory(total_memory_gb)
# Get same number as max forks based on memory, this function takes memory as bytes
mem_capacity = get_mem_effective_capacity(corrected_memory, is_control_node=True)
# Follow same process for CPU capacity constraint
cpu_count = get_cpu_count()
corrected_cpu = get_corrected_cpu(cpu_count)
cpu_capacity = get_cpu_effective_capacity(corrected_cpu, is_control_node=True)
# Here is what is different from health checks,
auto_max = max(mem_capacity, cpu_capacity)
# add magic number of extra workers to ensure
# we have a few extra workers to run the heartbeat
auto_max += 7
return auto_max
class AutoscalePool(WorkerPool): class AutoscalePool(WorkerPool):
""" """
An extended pool implementation that automatically scales workers up and An extended pool implementation that automatically scales workers up and
@@ -359,7 +320,19 @@ class AutoscalePool(WorkerPool):
super(AutoscalePool, self).__init__(*args, **kwargs) super(AutoscalePool, self).__init__(*args, **kwargs)
if self.max_workers is None: if self.max_workers is None:
self.max_workers = get_auto_max_workers() settings_absmem = getattr(settings, 'SYSTEM_TASK_ABS_MEM', None)
if settings_absmem is not None:
# There are 1073741824 bytes in a gigabyte. Convert bytes to gigabytes by dividing by 2**30
total_memory_gb = convert_mem_str_to_bytes(settings_absmem) // 2**30
else:
total_memory_gb = (psutil.virtual_memory().total >> 30) + 1 # noqa: round up
# Get same number as max forks based on memory, this function takes memory as bytes
self.max_workers = get_mem_effective_capacity(total_memory_gb * 2**30)
# add magic prime number of extra workers to ensure
# we have a few extra workers to run the heartbeat
self.max_workers += 7
# max workers can't be less than min_workers # max workers can't be less than min_workers
self.max_workers = max(self.min_workers, self.max_workers) self.max_workers = max(self.min_workers, self.max_workers)
@@ -373,9 +346,6 @@ class AutoscalePool(WorkerPool):
self.scale_up_ct = 0 self.scale_up_ct = 0
self.worker_count_max = 0 self.worker_count_max = 0
# last time we wrote current tasks, to avoid too much log spam
self.last_task_list_log = time.monotonic()
def produce_subsystem_metrics(self, metrics_object): def produce_subsystem_metrics(self, metrics_object):
metrics_object.set('dispatcher_pool_scale_up_events', self.scale_up_ct) metrics_object.set('dispatcher_pool_scale_up_events', self.scale_up_ct)
metrics_object.set('dispatcher_pool_active_task_count', sum(len(w.managed_tasks) for w in self.workers)) metrics_object.set('dispatcher_pool_active_task_count', sum(len(w.managed_tasks) for w in self.workers))
@@ -493,14 +463,6 @@ class AutoscalePool(WorkerPool):
self.worker_count_max = new_worker_ct self.worker_count_max = new_worker_ct
return ret return ret
@staticmethod
def fast_task_serialization(current_task):
try:
return str(current_task.get('task')) + ' - ' + str(sorted(current_task.get('args', []))) + ' - ' + str(sorted(current_task.get('kwargs', {})))
except Exception:
# just make sure this does not make things worse
return str(current_task)
def write(self, preferred_queue, body): def write(self, preferred_queue, body):
if 'guid' in body: if 'guid' in body:
set_guid(body['guid']) set_guid(body['guid'])
@@ -522,15 +484,6 @@ class AutoscalePool(WorkerPool):
if isinstance(body, dict): if isinstance(body, dict):
task_name = body.get('task') task_name = body.get('task')
logger.warning(f'Workers maxed, queuing {task_name}, load: {sum(len(w.managed_tasks) for w in self.workers)} / {len(self.workers)}') logger.warning(f'Workers maxed, queuing {task_name}, load: {sum(len(w.managed_tasks) for w in self.workers)} / {len(self.workers)}')
# Once every 10 seconds write out task list for debugging
if time.monotonic() - self.last_task_list_log >= 10.0:
task_counts = {}
for worker in self.workers:
task_slug = self.fast_task_serialization(worker.current_task)
task_counts.setdefault(task_slug, 0)
task_counts[task_slug] += 1
logger.info(f'Running tasks by count:\n{json.dumps(task_counts, indent=2)}')
self.last_task_list_log = time.monotonic()
return super(AutoscalePool, self).write(preferred_queue, body) return super(AutoscalePool, self).write(preferred_queue, body)
except Exception: except Exception:
for conn in connections.all(): for conn in connections.all():

View File

@@ -15,7 +15,6 @@ from datetime import timedelta
from django import db from django import db
from django.conf import settings from django.conf import settings
import redis.exceptions
from ansible_base.lib.logging.runtime import log_excess_runtime from ansible_base.lib.logging.runtime import log_excess_runtime
@@ -131,13 +130,10 @@ class AWXConsumerBase(object):
@log_excess_runtime(logger, debug_cutoff=0.05, cutoff=0.2) @log_excess_runtime(logger, debug_cutoff=0.05, cutoff=0.2)
def record_statistics(self): def record_statistics(self):
if time.time() - self.last_stats > 1: # buffer stat recording to once per second if time.time() - self.last_stats > 1: # buffer stat recording to once per second
save_data = self.pool.debug()
try: try:
self.redis.set(f'awx_{self.name}_statistics', save_data) self.redis.set(f'awx_{self.name}_statistics', self.pool.debug())
except redis.exceptions.ConnectionError as exc:
logger.warning(f'Redis connection error saving {self.name} status data:\n{exc}\nmissed data:\n{save_data}')
except Exception: except Exception:
logger.exception(f"Unknown redis error saving {self.name} status data:\nmissed data:\n{save_data}") logger.exception(f"encountered an error communicating with redis to store {self.name} statistics")
self.last_stats = time.time() self.last_stats = time.time()
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
@@ -193,10 +189,7 @@ class AWXConsumerPG(AWXConsumerBase):
current_time = time.time() current_time = time.time()
self.pool.produce_subsystem_metrics(self.subsystem_metrics) self.pool.produce_subsystem_metrics(self.subsystem_metrics)
self.subsystem_metrics.set('dispatcher_availability', self.listen_cumulative_time / (current_time - self.last_metrics_gather)) self.subsystem_metrics.set('dispatcher_availability', self.listen_cumulative_time / (current_time - self.last_metrics_gather))
try: self.subsystem_metrics.pipe_execute()
self.subsystem_metrics.pipe_execute()
except redis.exceptions.ConnectionError as exc:
logger.warning(f'Redis connection error saving dispatcher metrics, error:\n{exc}')
self.listen_cumulative_time = 0.0 self.listen_cumulative_time = 0.0
self.last_metrics_gather = current_time self.last_metrics_gather = current_time
@@ -212,11 +205,7 @@ class AWXConsumerPG(AWXConsumerBase):
except Exception as exc: except Exception as exc:
logger.warning(f'Failed to save dispatcher statistics {exc}') logger.warning(f'Failed to save dispatcher statistics {exc}')
# Everything benchmarks to the same original time, so that skews due to for job in self.scheduler.get_and_mark_pending():
# runtime of the actions, themselves, do not mess up scheduling expectations
reftime = time.time()
for job in self.scheduler.get_and_mark_pending(reftime=reftime):
if 'control' in job.data: if 'control' in job.data:
try: try:
job.data['control']() job.data['control']()
@@ -233,12 +222,12 @@ class AWXConsumerPG(AWXConsumerBase):
self.listen_start = time.time() self.listen_start = time.time()
return self.scheduler.time_until_next_run(reftime=reftime) return self.scheduler.time_until_next_run()
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
super(AWXConsumerPG, self).run(*args, **kwargs) super(AWXConsumerPG, self).run(*args, **kwargs)
logger.info(f"Running {self.name}, workers min={self.pool.min_workers} max={self.pool.max_workers}, listening to queues {self.queues}") logger.info(f"Running worker {self.name} listening to queues {self.queues}")
init = False init = False
while True: while True:

View File

@@ -86,7 +86,6 @@ class CallbackBrokerWorker(BaseWorker):
return os.getpid() return os.getpid()
def read(self, queue): def read(self, queue):
has_redis_error = False
try: try:
res = self.redis.blpop(self.queue_name, timeout=1) res = self.redis.blpop(self.queue_name, timeout=1)
if res is None: if res is None:
@@ -96,21 +95,14 @@ class CallbackBrokerWorker(BaseWorker):
self.subsystem_metrics.inc('callback_receiver_events_popped_redis', 1) self.subsystem_metrics.inc('callback_receiver_events_popped_redis', 1)
self.subsystem_metrics.inc('callback_receiver_events_in_memory', 1) self.subsystem_metrics.inc('callback_receiver_events_in_memory', 1)
return json.loads(res[1]) return json.loads(res[1])
except redis.exceptions.ConnectionError as exc:
# Low noise log, because very common and many workers will write this
logger.error(f"redis connection error: {exc}")
has_redis_error = True
time.sleep(5)
except redis.exceptions.RedisError: except redis.exceptions.RedisError:
logger.exception("encountered an error communicating with redis") logger.exception("encountered an error communicating with redis")
has_redis_error = True
time.sleep(1) time.sleep(1)
except (json.JSONDecodeError, KeyError): except (json.JSONDecodeError, KeyError):
logger.exception("failed to decode JSON message from redis") logger.exception("failed to decode JSON message from redis")
finally: finally:
if not has_redis_error: self.record_statistics()
self.record_statistics() self.record_read_metrics()
self.record_read_metrics()
return {'event': 'FLUSH'} return {'event': 'FLUSH'}

View File

@@ -38,12 +38,5 @@ class PostRunError(Exception):
super(PostRunError, self).__init__(msg) super(PostRunError, self).__init__(msg)
class PolicyEvaluationError(Exception):
def __init__(self, msg, status='failed', tb=''):
self.status = status
self.tb = tb
super(PolicyEvaluationError, self).__init__(msg)
class ReceptorNodeNotFound(RuntimeError): class ReceptorNodeNotFound(RuntimeError):
pass pass

View File

@@ -33,7 +33,6 @@ from awx.main.utils.safe_yaml import sanitize_jinja
from awx.main.models.rbac import batch_role_ancestor_rebuilding from awx.main.models.rbac import batch_role_ancestor_rebuilding
from awx.main.utils import ignore_inventory_computed_fields, get_licenser from awx.main.utils import ignore_inventory_computed_fields, get_licenser
from awx.main.utils.execution_environments import get_default_execution_environment from awx.main.utils.execution_environments import get_default_execution_environment
from awx.main.utils.inventory_vars import update_group_variables
from awx.main.signals import disable_activity_stream from awx.main.signals import disable_activity_stream
from awx.main.constants import STANDARD_INVENTORY_UPDATE_ENV from awx.main.constants import STANDARD_INVENTORY_UPDATE_ENV
@@ -458,19 +457,19 @@ class Command(BaseCommand):
""" """
Update inventory variables from "all" group. Update inventory variables from "all" group.
""" """
# TODO: We disable variable overwrite here in case user-defined inventory variables get
# mangled. But we still need to figure out a better way of processing multiple inventory
# update variables mixing with each other.
# issue for this: https://github.com/ansible/awx/issues/11623
if self.inventory.kind == 'constructed' and self.inventory_source.overwrite_vars: if self.inventory.kind == 'constructed' and self.inventory_source.overwrite_vars:
# NOTE: we had to add a exception case to not merge variables # NOTE: we had to add a exception case to not merge variables
# to make constructed inventory coherent # to make constructed inventory coherent
db_variables = self.all_group.variables db_variables = self.all_group.variables
else: else:
db_variables = update_group_variables( db_variables = self.inventory.variables_dict
group_id=None, # `None` denotes the 'all' group (which doesn't have a pk). db_variables.update(self.all_group.variables)
newvars=self.all_group.variables,
dbvars=self.inventory.variables_dict,
invsrc_id=self.inventory_source.id,
inventory_id=self.inventory.id,
overwrite_vars=self.overwrite_vars,
)
if db_variables != self.inventory.variables_dict: if db_variables != self.inventory.variables_dict:
self.inventory.variables = json.dumps(db_variables) self.inventory.variables = json.dumps(db_variables)
self.inventory.save(update_fields=['variables']) self.inventory.save(update_fields=['variables'])

View File

@@ -1,13 +1,10 @@
# Copyright (c) 2015 Ansible, Inc. # Copyright (c) 2015 Ansible, Inc.
# All Rights Reserved. # All Rights Reserved.
import redis
from django.conf import settings from django.conf import settings
from django.core.management.base import BaseCommand, CommandError from django.core.management.base import BaseCommand
import redis.exceptions
from awx.main.analytics.subsystem_metrics import CallbackReceiverMetricsServer from awx.main.analytics.subsystem_metrics import CallbackReceiverMetricsServer
from awx.main.dispatch.control import Control from awx.main.dispatch.control import Control
from awx.main.dispatch.worker import AWXConsumerRedis, CallbackBrokerWorker from awx.main.dispatch.worker import AWXConsumerRedis, CallbackBrokerWorker
@@ -30,10 +27,7 @@ class Command(BaseCommand):
return return
consumer = None consumer = None
try: CallbackReceiverMetricsServer().start()
CallbackReceiverMetricsServer().start()
except redis.exceptions.ConnectionError as exc:
raise CommandError(f'Callback receiver could not connect to redis, error: {exc}')
try: try:
consumer = AWXConsumerRedis( consumer = AWXConsumerRedis(

View File

@@ -3,10 +3,8 @@
import logging import logging
import yaml import yaml
import redis
from django.conf import settings from django.conf import settings
from django.core.management.base import BaseCommand, CommandError from django.core.management.base import BaseCommand
from awx.main.dispatch import get_task_queuename from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.control import Control from awx.main.dispatch.control import Control
@@ -65,10 +63,7 @@ class Command(BaseCommand):
consumer = None consumer = None
try: DispatcherMetricsServer().start()
DispatcherMetricsServer().start()
except redis.exceptions.ConnectionError as exc:
raise CommandError(f'Dispatcher could not connect to redis, error: {exc}')
try: try:
queues = ['tower_broadcast_all', 'tower_settings_change', get_task_queuename()] queues = ['tower_broadcast_all', 'tower_settings_change', get_task_queuename()]

View File

@@ -1,46 +0,0 @@
# Generated by Django 4.2.18 on 2025-03-17 16:10
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('main', '0196_indirect_managed_node_audit'),
]
operations = [
migrations.AddField(
model_name='inventory',
name='opa_query_path',
field=models.CharField(
blank=True,
default=None,
help_text='The query path for the OPA policy to evaluate prior to job execution. The query path should be formatted as package/rule.',
max_length=128,
null=True,
),
),
migrations.AddField(
model_name='jobtemplate',
name='opa_query_path',
field=models.CharField(
blank=True,
default=None,
help_text='The query path for the OPA policy to evaluate prior to job execution. The query path should be formatted as package/rule.',
max_length=128,
null=True,
),
),
migrations.AddField(
model_name='organization',
name='opa_query_path',
field=models.CharField(
blank=True,
default=None,
help_text='The query path for the OPA policy to evaluate prior to job execution. The query path should be formatted as package/rule.',
max_length=128,
null=True,
),
),
]

View File

@@ -5,7 +5,7 @@ from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
('main', '0200_template_name_constraint'), ('main', '0196_indirect_managed_node_audit'),
] ]
operations = [ operations = [

View File

@@ -1,61 +0,0 @@
# Generated by Django 4.2.18 on 2025-02-27 20:35
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [('main', '0197_add_opa_query_path')]
operations = [
migrations.AlterField(
model_name='inventorysource',
name='source',
field=models.CharField(
choices=[
('file', 'File, Directory or Script'),
('constructed', 'Template additional groups and hostvars at runtime'),
('scm', 'Sourced from a Project'),
('ec2', 'Amazon EC2'),
('gce', 'Google Compute Engine'),
('azure_rm', 'Microsoft Azure Resource Manager'),
('vmware', 'VMware vCenter'),
('vmware_esxi', 'VMware ESXi'),
('satellite6', 'Red Hat Satellite 6'),
('openstack', 'OpenStack'),
('rhv', 'Red Hat Virtualization'),
('controller', 'Red Hat Ansible Automation Platform'),
('insights', 'Red Hat Insights'),
('terraform', 'Terraform State'),
('openshift_virtualization', 'OpenShift Virtualization'),
],
default=None,
max_length=32,
),
),
migrations.AlterField(
model_name='inventoryupdate',
name='source',
field=models.CharField(
choices=[
('file', 'File, Directory or Script'),
('constructed', 'Template additional groups and hostvars at runtime'),
('scm', 'Sourced from a Project'),
('ec2', 'Amazon EC2'),
('gce', 'Google Compute Engine'),
('azure_rm', 'Microsoft Azure Resource Manager'),
('vmware', 'VMware vCenter'),
('vmware_esxi', 'VMware ESXi'),
('satellite6', 'Red Hat Satellite 6'),
('openstack', 'OpenStack'),
('rhv', 'Red Hat Virtualization'),
('controller', 'Red Hat Ansible Automation Platform'),
('insights', 'Red Hat Insights'),
('terraform', 'Terraform State'),
('openshift_virtualization', 'OpenShift Virtualization'),
],
default=None,
max_length=32,
),
),
]

View File

@@ -5,7 +5,7 @@ from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
('main', '0201_delete_profile'), ('main', '0197_delete_profile'),
] ]
operations = [ operations = [

View File

@@ -6,7 +6,7 @@ from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
('main', '0202_remove_sso_app_content'), ('main', '0198_remove_sso_app_content'),
] ]
operations = [ operations = [

View File

@@ -1,32 +0,0 @@
# Generated by Django 4.2.20 on 2025-04-24 09:08
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('main', '0198_alter_inventorysource_source_and_more'),
]
operations = [
migrations.CreateModel(
name='InventoryGroupVariablesWithHistory',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('variables', models.JSONField()),
('group', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='inventory_group_variables', to='main.group')),
(
'inventory',
models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='inventory_group_variables', to='main.inventory'),
),
],
),
migrations.AddConstraint(
model_name='inventorygroupvariableswithhistory',
constraint=models.UniqueConstraint(
fields=('inventory', 'group'), name='unique_inventory_group', violation_error_message='Inventory/Group combination must be unique.'
),
),
]

View File

@@ -6,7 +6,7 @@ from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
('main', '0203_alter_inventorysource_source_and_more'), ('main', '0199_alter_inventorysource_source_and_more'),
] ]
operations = [ operations = [

View File

@@ -1,50 +0,0 @@
# Generated by Django 4.2.20 on 2025-04-22 15:54
import logging
from django.db import migrations, models
from awx.main.migrations._db_constraints import _rename_duplicates
logger = logging.getLogger(__name__)
def rename_jts(apps, schema_editor):
cls = apps.get_model('main', 'JobTemplate')
_rename_duplicates(cls)
def rename_projects(apps, schema_editor):
cls = apps.get_model('main', 'Project')
_rename_duplicates(cls)
def change_inventory_source_org_unique(apps, schema_editor):
cls = apps.get_model('main', 'InventorySource')
r = cls.objects.update(org_unique=False)
logger.info(f'Set database constraint rule for {r} inventory source objects')
class Migration(migrations.Migration):
dependencies = [
('main', '0199_inventorygroupvariableswithhistory_and_more'),
]
operations = [
migrations.RunPython(rename_jts, migrations.RunPython.noop),
migrations.RunPython(rename_projects, migrations.RunPython.noop),
migrations.AddField(
model_name='unifiedjobtemplate',
name='org_unique',
field=models.BooleanField(blank=True, default=True, editable=False, help_text='Used internally to selectively enforce database constraint on name'),
),
migrations.RunPython(change_inventory_source_org_unique, migrations.RunPython.noop),
migrations.AddConstraint(
model_name='unifiedjobtemplate',
constraint=models.UniqueConstraint(
condition=models.Q(('org_unique', True)), fields=('polymorphic_ctype', 'name', 'organization'), name='ujt_hard_name_constraint'
),
),
]

View File

@@ -8,7 +8,7 @@ from awx.main.migrations._create_system_jobs import delete_clear_tokens_sjt
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
('main', '0204_alter_oauth2application_unique_together_and_more'), ('main', '0200_alter_oauth2application_unique_together_and_more'),
] ]
operations = [ operations = [

View File

@@ -1,22 +0,0 @@
# Generated by Django 4.2.20 on 2025-05-22 08:57
from decimal import Decimal
import django.core.validators
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('main', '0205_delete_token_cleanup_job'),
]
operations = [
migrations.AlterField(
model_name='instance',
name='capacity_adjustment',
field=models.DecimalField(
decimal_places=2, default=Decimal('0.75'), max_digits=3, validators=[django.core.validators.MinValueValidator(Decimal('0'))]
),
),
]

View File

@@ -1,25 +0,0 @@
import logging
from django.db.models import Count
logger = logging.getLogger(__name__)
def _rename_duplicates(cls):
field = cls._meta.get_field('name')
max_len = field.max_length
for organization_id in cls.objects.order_by().values_list('organization_id', flat=True).distinct():
duplicate_data = cls.objects.values('name').filter(organization_id=organization_id).annotate(count=Count('name')).order_by().filter(count__gt=1)
for data in duplicate_data:
name = data['name']
for idx, ujt in enumerate(cls.objects.filter(name=name, organization_id=organization_id).order_by('created')):
if idx > 0:
suffix = f'_dup{idx}'
max_chars = max_len - len(suffix)
if len(ujt.name) >= max_chars:
ujt.name = ujt.name[:max_chars] + suffix
else:
ujt.name = ujt.name + suffix
logger.info(f'Renaming duplicate {cls._meta.model_name} to `{ujt.name}` because of duplicate name entry')
ujt.save(update_fields=['name'])

View File

@@ -33,7 +33,6 @@ from awx.main.models.inventory import ( # noqa
InventorySource, InventorySource,
InventoryUpdate, InventoryUpdate,
SmartInventoryMembership, SmartInventoryMembership,
InventoryGroupVariablesWithHistory,
) )
from awx.main.models.jobs import ( # noqa from awx.main.models.jobs import ( # noqa
Job, Job,

View File

@@ -550,10 +550,10 @@ class CredentialType(CommonModelNameNotUnique):
# TODO: User "side-loaded" credential custom_injectors isn't supported # TODO: User "side-loaded" credential custom_injectors isn't supported
ManagedCredentialType.registry[ns] = SimpleNamespace(namespace=ns, name=plugin.name, kind='external', inputs=plugin.inputs, backend=plugin.backend) ManagedCredentialType.registry[ns] = SimpleNamespace(namespace=ns, name=plugin.name, kind='external', inputs=plugin.inputs, backend=plugin.backend)
def inject_credential(self, credential, env, safe_env, args, private_data_dir, container_root=None): def inject_credential(self, credential, env, safe_env, args, private_data_dir):
from awx_plugins.interfaces._temporary_private_inject_api import inject_credential from awx_plugins.interfaces._temporary_private_inject_api import inject_credential
inject_credential(self, credential, env, safe_env, args, private_data_dir, container_root=container_root) inject_credential(self, credential, env, safe_env, args, private_data_dir)
class CredentialTypeHelper: class CredentialTypeHelper:

View File

@@ -24,7 +24,6 @@ from awx.main.managers import DeferJobCreatedManager
from awx.main.constants import MINIMAL_EVENTS from awx.main.constants import MINIMAL_EVENTS
from awx.main.models.base import CreatedModifiedModel from awx.main.models.base import CreatedModifiedModel
from awx.main.utils import ignore_inventory_computed_fields, camelcase_to_underscore from awx.main.utils import ignore_inventory_computed_fields, camelcase_to_underscore
from awx.main.utils.db import bulk_update_sorted_by_id
analytics_logger = logging.getLogger('awx.analytics.job_events') analytics_logger = logging.getLogger('awx.analytics.job_events')
@@ -566,6 +565,7 @@ class JobEvent(BasePlaybookEvent):
summaries = dict() summaries = dict()
updated_hosts_list = list() updated_hosts_list = list()
for host in hostnames: for host in hostnames:
updated_hosts_list.append(host.lower())
host_id = host_map.get(host) host_id = host_map.get(host)
if host_id not in existing_host_ids: if host_id not in existing_host_ids:
host_id = None host_id = None
@@ -582,12 +582,6 @@ class JobEvent(BasePlaybookEvent):
summary.failed = bool(summary.dark or summary.failures) summary.failed = bool(summary.dark or summary.failures)
summaries[(host_id, host)] = summary summaries[(host_id, host)] = summary
# do not count dark / unreachable hosts as updated
if not bool(summary.dark):
updated_hosts_list.append(host.lower())
else:
logger.warning(f'host {host.lower()} is dark / unreachable, not marking it as updated')
JobHostSummary.objects.bulk_create(summaries.values()) JobHostSummary.objects.bulk_create(summaries.values())
# update the last_job_id and last_job_host_summary_id # update the last_job_id and last_job_host_summary_id
@@ -603,7 +597,7 @@ class JobEvent(BasePlaybookEvent):
h.last_job_host_summary_id = host_mapping[h.id] h.last_job_host_summary_id = host_mapping[h.id]
updated_hosts.add(h) updated_hosts.add(h)
bulk_update_sorted_by_id(Host, updated_hosts, ['last_job_id', 'last_job_host_summary_id']) Host.objects.bulk_update(list(updated_hosts), ['last_job_id', 'last_job_host_summary_id'], batch_size=100)
# Create/update Host Metrics # Create/update Host Metrics
self._update_host_metrics(updated_hosts_list) self._update_host_metrics(updated_hosts_list)

View File

@@ -160,7 +160,7 @@ class Instance(HasPolicyEditsMixin, BaseModel):
default=100, default=100,
editable=False, editable=False,
) )
capacity_adjustment = models.DecimalField(default=Decimal(0.75), max_digits=3, decimal_places=2, validators=[MinValueValidator(Decimal(0.0))]) capacity_adjustment = models.DecimalField(default=Decimal(1.0), max_digits=3, decimal_places=2, validators=[MinValueValidator(Decimal(0.0))])
enabled = models.BooleanField(default=True) enabled = models.BooleanField(default=True)
managed_by_policy = models.BooleanField(default=True) managed_by_policy = models.BooleanField(default=True)

View File

@@ -43,7 +43,6 @@ from awx.main.models.mixins import (
TaskManagerInventoryUpdateMixin, TaskManagerInventoryUpdateMixin,
RelatedJobsMixin, RelatedJobsMixin,
CustomVirtualEnvMixin, CustomVirtualEnvMixin,
OpaQueryPathMixin,
) )
from awx.main.models.notifications import ( from awx.main.models.notifications import (
NotificationTemplate, NotificationTemplate,
@@ -69,7 +68,7 @@ class InventoryConstructedInventoryMembership(models.Model):
) )
class Inventory(CommonModelNameNotUnique, ResourceMixin, RelatedJobsMixin, OpaQueryPathMixin): class Inventory(CommonModelNameNotUnique, ResourceMixin, RelatedJobsMixin):
""" """
an inventory source contains lists and hosts. an inventory source contains lists and hosts.
""" """
@@ -1120,10 +1119,8 @@ class InventorySource(UnifiedJobTemplate, InventorySourceOptions, CustomVirtualE
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
# if this is a new object, inherit organization from its inventory # if this is a new object, inherit organization from its inventory
if not self.pk: if not self.pk and self.inventory and self.inventory.organization_id and not self.organization_id:
self.org_unique = False # needed to exclude from unique (name, organization) constraint self.organization_id = self.inventory.organization_id
if self.inventory and self.inventory.organization_id and not self.organization_id:
self.organization_id = self.inventory.organization_id
# If update_fields has been specified, add our field names to it, # If update_fields has been specified, add our field names to it,
# if it hasn't been specified, then we're just doing a normal save. # if it hasn't been specified, then we're just doing a normal save.
@@ -1404,38 +1401,3 @@ class CustomInventoryScript(CommonModelNameNotUnique):
def get_absolute_url(self, request=None): def get_absolute_url(self, request=None):
return reverse('api:inventory_script_detail', kwargs={'pk': self.pk}, request=request) return reverse('api:inventory_script_detail', kwargs={'pk': self.pk}, request=request)
class InventoryGroupVariablesWithHistory(models.Model):
"""
Represents the inventory variables of one inventory group.
The purpose of this model is to persist the update history of the group
variables. The update history is maintained in another class
(`InventoryGroupVariables`), this class here is just a container for the
database storage.
"""
class Meta:
constraints = [
# Do not allow the same inventory/group combination more than once.
models.UniqueConstraint(
fields=["inventory", "group"],
name="unique_inventory_group",
violation_error_message=_("Inventory/Group combination must be unique."),
),
]
inventory = models.ForeignKey(
'Inventory',
related_name='inventory_group_variables',
null=True,
on_delete=models.CASCADE,
)
group = models.ForeignKey( # `None` denotes the 'all'-group.
'Group',
related_name='inventory_group_variables',
null=True,
on_delete=models.CASCADE,
)
variables = models.JSONField() # The group variables, including their history.

View File

@@ -51,7 +51,6 @@ from awx.main.models.mixins import (
RelatedJobsMixin, RelatedJobsMixin,
WebhookMixin, WebhookMixin,
WebhookTemplateMixin, WebhookTemplateMixin,
OpaQueryPathMixin,
) )
from awx.main.constants import JOB_VARIABLE_PREFIXES from awx.main.constants import JOB_VARIABLE_PREFIXES
@@ -193,9 +192,7 @@ class JobOptions(BaseModel):
return needed return needed
class JobTemplate( class JobTemplate(UnifiedJobTemplate, JobOptions, SurveyJobTemplateMixin, ResourceMixin, CustomVirtualEnvMixin, RelatedJobsMixin, WebhookTemplateMixin):
UnifiedJobTemplate, JobOptions, SurveyJobTemplateMixin, ResourceMixin, CustomVirtualEnvMixin, RelatedJobsMixin, WebhookTemplateMixin, OpaQueryPathMixin
):
""" """
A job template is a reusable job definition for applying a project (with A job template is a reusable job definition for applying a project (with
playbook) to an inventory source with a given credential. playbook) to an inventory source with a given credential.
@@ -358,6 +355,26 @@ class JobTemplate(
update_fields.append('organization_id') update_fields.append('organization_id')
return super(JobTemplate, self).save(*args, **kwargs) return super(JobTemplate, self).save(*args, **kwargs)
def validate_unique(self, exclude=None):
"""Custom over-ride for JT specifically
because organization is inferred from project after full_clean is finished
thus the organization field is not yet set when validation happens
"""
errors = []
for ut in JobTemplate.SOFT_UNIQUE_TOGETHER:
kwargs = {'name': self.name}
if self.project:
kwargs['organization'] = self.project.organization_id
else:
kwargs['organization'] = None
qs = JobTemplate.objects.filter(**kwargs)
if self.pk:
qs = qs.exclude(pk=self.pk)
if qs.exists():
errors.append('%s with this (%s) combination already exists.' % (JobTemplate.__name__, ', '.join(set(ut) - {'polymorphic_ctype'})))
if errors:
raise ValidationError(errors)
def create_unified_job(self, **kwargs): def create_unified_job(self, **kwargs):
prevent_slicing = kwargs.pop('_prevent_slicing', False) prevent_slicing = kwargs.pop('_prevent_slicing', False)
slice_ct = self.get_effective_slice_ct(kwargs) slice_ct = self.get_effective_slice_ct(kwargs)
@@ -384,26 +401,6 @@ class JobTemplate(
WorkflowJobNode.objects.create(**create_kwargs) WorkflowJobNode.objects.create(**create_kwargs)
return job return job
def validate_unique(self, exclude=None):
"""Custom over-ride for JT specifically
because organization is inferred from project after full_clean is finished
thus the organization field is not yet set when validation happens
"""
errors = []
for ut in JobTemplate.SOFT_UNIQUE_TOGETHER:
kwargs = {'name': self.name}
if self.project:
kwargs['organization'] = self.project.organization_id
else:
kwargs['organization'] = None
qs = JobTemplate.objects.filter(**kwargs)
if self.pk:
qs = qs.exclude(pk=self.pk)
if qs.exists():
errors.append('%s with this (%s) combination already exists.' % (JobTemplate.__name__, ', '.join(set(ut) - {'polymorphic_ctype'})))
if errors:
raise ValidationError(errors)
def get_absolute_url(self, request=None): def get_absolute_url(self, request=None):
return reverse('api:job_template_detail', kwargs={'pk': self.pk}, request=request) return reverse('api:job_template_detail', kwargs={'pk': self.pk}, request=request)

View File

@@ -42,7 +42,6 @@ __all__ = [
'TaskManagerInventoryUpdateMixin', 'TaskManagerInventoryUpdateMixin',
'ExecutionEnvironmentMixin', 'ExecutionEnvironmentMixin',
'CustomVirtualEnvMixin', 'CustomVirtualEnvMixin',
'OpaQueryPathMixin',
] ]
@@ -693,16 +692,3 @@ class WebhookMixin(models.Model):
logger.debug("Webhook status update sent.") logger.debug("Webhook status update sent.")
else: else:
logger.error("Posting webhook status failed, code: {}\n" "{}\nPayload sent: {}".format(response.status_code, response.text, json.dumps(data))) logger.error("Posting webhook status failed, code: {}\n" "{}\nPayload sent: {}".format(response.status_code, response.text, json.dumps(data)))
class OpaQueryPathMixin(models.Model):
class Meta:
abstract = True
opa_query_path = models.CharField(
max_length=128,
blank=True,
null=True,
default=None,
help_text=_("The query path for the OPA policy to evaluate prior to job execution. The query path should be formatted as package/rule."),
)

View File

@@ -22,12 +22,12 @@ from awx.main.models.rbac import (
ROLE_SINGLETON_SYSTEM_AUDITOR, ROLE_SINGLETON_SYSTEM_AUDITOR,
) )
from awx.main.models.unified_jobs import UnifiedJob from awx.main.models.unified_jobs import UnifiedJob
from awx.main.models.mixins import ResourceMixin, CustomVirtualEnvMixin, RelatedJobsMixin, OpaQueryPathMixin from awx.main.models.mixins import ResourceMixin, CustomVirtualEnvMixin, RelatedJobsMixin
__all__ = ['Organization', 'Team', 'UserSessionMembership'] __all__ = ['Organization', 'Team', 'UserSessionMembership']
class Organization(CommonModel, NotificationFieldsModel, ResourceMixin, CustomVirtualEnvMixin, RelatedJobsMixin, OpaQueryPathMixin): class Organization(CommonModel, NotificationFieldsModel, ResourceMixin, CustomVirtualEnvMixin, RelatedJobsMixin):
""" """
An organization is the basic unit of multi-tenancy divisions An organization is the basic unit of multi-tenancy divisions
""" """

View File

@@ -18,7 +18,6 @@ from collections import OrderedDict
# Django # Django
from django.conf import settings from django.conf import settings
from django.db import models, connection, transaction from django.db import models, connection, transaction
from django.db.models.constraints import UniqueConstraint
from django.core.exceptions import NON_FIELD_ERRORS from django.core.exceptions import NON_FIELD_ERRORS
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django.utils.timezone import now from django.utils.timezone import now
@@ -112,10 +111,7 @@ class UnifiedJobTemplate(PolymorphicModel, CommonModelNameNotUnique, ExecutionEn
ordering = ('name',) ordering = ('name',)
# unique_together here is intentionally commented out. Please make sure sub-classes of this model # unique_together here is intentionally commented out. Please make sure sub-classes of this model
# contain at least this uniqueness restriction: SOFT_UNIQUE_TOGETHER = [('polymorphic_ctype', 'name')] # contain at least this uniqueness restriction: SOFT_UNIQUE_TOGETHER = [('polymorphic_ctype', 'name')]
# Unique name constraint - note that inventory source model is excluded from this constraint entirely # unique_together = [('polymorphic_ctype', 'name', 'organization')]
constraints = [
UniqueConstraint(fields=['polymorphic_ctype', 'name', 'organization'], condition=models.Q(org_unique=True), name='ujt_hard_name_constraint')
]
old_pk = models.PositiveIntegerField( old_pk = models.PositiveIntegerField(
null=True, null=True,
@@ -184,9 +180,6 @@ class UnifiedJobTemplate(PolymorphicModel, CommonModelNameNotUnique, ExecutionEn
) )
labels = models.ManyToManyField("Label", blank=True, related_name='%(class)s_labels') labels = models.ManyToManyField("Label", blank=True, related_name='%(class)s_labels')
instance_groups = OrderedManyToManyField('InstanceGroup', blank=True, through='UnifiedJobTemplateInstanceGroupMembership') instance_groups = OrderedManyToManyField('InstanceGroup', blank=True, through='UnifiedJobTemplateInstanceGroupMembership')
org_unique = models.BooleanField(
blank=True, default=True, editable=False, help_text=_('Used internally to selectively enforce database constraint on name')
)
def get_absolute_url(self, request=None): def get_absolute_url(self, request=None):
real_instance = self.get_real_instance() real_instance = self.get_real_instance()

View File

@@ -53,8 +53,8 @@ class GrafanaBackend(AWXBaseEmailBackend, CustomNotificationBase):
): ):
super(GrafanaBackend, self).__init__(fail_silently=fail_silently) super(GrafanaBackend, self).__init__(fail_silently=fail_silently)
self.grafana_key = grafana_key self.grafana_key = grafana_key
self.dashboardId = int(dashboardId) if dashboardId is not None and panelId != "" else None self.dashboardId = int(dashboardId) if dashboardId is not None else None
self.panelId = int(panelId) if panelId is not None and panelId != "" else None self.panelId = int(panelId) if panelId is not None else None
self.annotation_tags = annotation_tags if annotation_tags is not None else [] self.annotation_tags = annotation_tags if annotation_tags is not None else []
self.grafana_no_verify_ssl = grafana_no_verify_ssl self.grafana_no_verify_ssl = grafana_no_verify_ssl
self.isRegion = isRegion self.isRegion = isRegion
@@ -97,7 +97,6 @@ class GrafanaBackend(AWXBaseEmailBackend, CustomNotificationBase):
r = requests.post( r = requests.post(
"{}/api/annotations".format(m.recipients()[0]), json=grafana_data, headers=grafana_headers, verify=(not self.grafana_no_verify_ssl) "{}/api/annotations".format(m.recipients()[0]), json=grafana_data, headers=grafana_headers, verify=(not self.grafana_no_verify_ssl)
) )
if r.status_code >= 400: if r.status_code >= 400:
logger.error(smart_str(_("Error sending notification grafana: {}").format(r.status_code))) logger.error(smart_str(_("Error sending notification grafana: {}").format(r.status_code)))
if not self.fail_silently: if not self.fail_silently:

View File

@@ -174,9 +174,6 @@ class PodManager(object):
) )
pod_spec['spec']['containers'][0]['name'] = self.pod_name pod_spec['spec']['containers'][0]['name'] = self.pod_name
# Prevent mounting of service account token in job pods in order to prevent job pods from accessing the k8s API via in cluster service account auth
pod_spec['spec']['automountServiceAccountToken'] = False
return pod_spec return pod_spec

View File

@@ -10,8 +10,6 @@ import time
import sys import sys
import signal import signal
import redis
# Django # Django
from django.db import transaction from django.db import transaction
from django.utils.translation import gettext_lazy as _, gettext_noop from django.utils.translation import gettext_lazy as _, gettext_noop
@@ -122,8 +120,6 @@ class TaskBase:
self.subsystem_metrics.pipe_execute() self.subsystem_metrics.pipe_execute()
else: else:
logger.debug(f"skipping recording {self.prefix} metrics, last recorded {time_last_recorded} seconds ago") logger.debug(f"skipping recording {self.prefix} metrics, last recorded {time_last_recorded} seconds ago")
except redis.exceptions.ConnectionError as exc:
logger.warning(f"Redis connection error saving metrics for {self.prefix}, error: {exc}")
except Exception: except Exception:
logger.exception(f"Error saving metrics for {self.prefix}") logger.exception(f"Error saving metrics for {self.prefix}")

View File

@@ -6,15 +6,16 @@ import logging
# Django # Django
from django.conf import settings from django.conf import settings
from django.db.models.query import QuerySet
from django.utils.encoding import smart_str from django.utils.encoding import smart_str
from django.utils.timezone import now from django.utils.timezone import now
from django.db import OperationalError
# django-ansible-base # django-ansible-base
from ansible_base.lib.logging.runtime import log_excess_runtime from ansible_base.lib.logging.runtime import log_excess_runtime
# AWX # AWX
from awx.main.utils.db import bulk_update_sorted_by_id from awx.main.models.inventory import Host
from awx.main.models import Host
logger = logging.getLogger('awx.main.tasks.facts') logger = logging.getLogger('awx.main.tasks.facts')
@@ -22,51 +23,63 @@ system_tracking_logger = logging.getLogger('awx.analytics.system_tracking')
@log_excess_runtime(logger, debug_cutoff=0.01, msg='Inventory {inventory_id} host facts prepared for {written_ct} hosts, took {delta:.3f} s', add_log_data=True) @log_excess_runtime(logger, debug_cutoff=0.01, msg='Inventory {inventory_id} host facts prepared for {written_ct} hosts, took {delta:.3f} s', add_log_data=True)
def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_data=None): def start_fact_cache(hosts, destination, log_data, timeout=None, inventory_id=None):
log_data = log_data or {}
log_data['inventory_id'] = inventory_id log_data['inventory_id'] = inventory_id
log_data['written_ct'] = 0 log_data['written_ct'] = 0
hosts_cached = [] try:
os.makedirs(destination, mode=0o700)
# Create the fact_cache directory inside artifacts_dir except FileExistsError:
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache') pass
os.makedirs(fact_cache_dir, mode=0o700, exist_ok=True)
if timeout is None: if timeout is None:
timeout = settings.ANSIBLE_FACT_CACHE_TIMEOUT timeout = settings.ANSIBLE_FACT_CACHE_TIMEOUT
last_write_time = None if isinstance(hosts, QuerySet):
hosts = hosts.iterator()
last_filepath_written = None
for host in hosts: for host in hosts:
hosts_cached.append(host.name) if (not host.ansible_facts_modified) or (timeout and host.ansible_facts_modified < now() - datetime.timedelta(seconds=timeout)):
if not host.ansible_facts_modified or (timeout and host.ansible_facts_modified < now() - datetime.timedelta(seconds=timeout)):
continue # facts are expired - do not write them continue # facts are expired - do not write them
filepath = os.sep.join(map(str, [destination, host.name]))
filepath = os.path.join(fact_cache_dir, host.name) if not os.path.realpath(filepath).startswith(destination):
if not os.path.realpath(filepath).startswith(fact_cache_dir): system_tracking_logger.error('facts for host {} could not be cached'.format(smart_str(host.name)))
logger.error(f'facts for host {smart_str(host.name)} could not be cached')
continue continue
try: try:
with codecs.open(filepath, 'w', encoding='utf-8') as f: with codecs.open(filepath, 'w', encoding='utf-8') as f:
os.chmod(f.name, 0o600) os.chmod(f.name, 0o600)
json.dump(host.ansible_facts, f) json.dump(host.ansible_facts, f)
log_data['written_ct'] += 1 log_data['written_ct'] += 1
last_write_time = os.path.getmtime(filepath) last_filepath_written = filepath
except IOError: except IOError:
logger.error(f'facts for host {smart_str(host.name)} could not be cached') system_tracking_logger.error('facts for host {} could not be cached'.format(smart_str(host.name)))
continue continue
# make note of the time we wrote the last file so we can check if any file changed later
if last_filepath_written:
return os.path.getmtime(last_filepath_written)
return None
# Write summary file directly to the artifacts_dir
if inventory_id is not None: def raw_update_hosts(host_list):
summary_file = os.path.join(artifacts_dir, 'host_cache_summary.json') Host.objects.bulk_update(host_list, ['ansible_facts', 'ansible_facts_modified'])
summary_data = {
'last_write_time': last_write_time,
'hosts_cached': hosts_cached, def update_hosts(host_list, max_tries=5):
'written_ct': log_data['written_ct'], if not host_list:
} return
with open(summary_file, 'w', encoding='utf-8') as f: for i in range(max_tries):
json.dump(summary_data, f, indent=2) try:
raw_update_hosts(host_list)
except OperationalError as exc:
# Deadlocks can happen if this runs at the same time as another large query
# inventory updates and updating last_job_host_summary are candidates for conflict
# but these would resolve easily on a retry
if i + 1 < max_tries:
logger.info(f'OperationalError (suspected deadlock) saving host facts retry {i}, message: {exc}')
continue
else:
raise
break
@log_excess_runtime( @log_excess_runtime(
@@ -75,54 +88,35 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
msg='Inventory {inventory_id} host facts: updated {updated_ct}, cleared {cleared_ct}, unchanged {unmodified_ct}, took {delta:.3f} s', msg='Inventory {inventory_id} host facts: updated {updated_ct}, cleared {cleared_ct}, unchanged {unmodified_ct}, took {delta:.3f} s',
add_log_data=True, add_log_data=True,
) )
def finish_fact_cache(artifacts_dir, job_id=None, inventory_id=None, log_data=None): def finish_fact_cache(hosts, destination, facts_write_time, log_data, job_id=None, inventory_id=None):
log_data = log_data or {}
log_data['inventory_id'] = inventory_id log_data['inventory_id'] = inventory_id
log_data['updated_ct'] = 0 log_data['updated_ct'] = 0
log_data['unmodified_ct'] = 0 log_data['unmodified_ct'] = 0
log_data['cleared_ct'] = 0 log_data['cleared_ct'] = 0
# The summary file is directly inside the artifacts dir
summary_path = os.path.join(artifacts_dir, 'host_cache_summary.json')
if not os.path.exists(summary_path):
logger.error(f'Missing summary file at {summary_path}')
return
try: if isinstance(hosts, QuerySet):
with open(summary_path, 'r', encoding='utf-8') as f: hosts = hosts.iterator()
summary = json.load(f)
facts_write_time = os.path.getmtime(summary_path) # After successful read
except (json.JSONDecodeError, OSError) as e:
logger.error(f'Error reading summary file at {summary_path}: {e}')
return
host_names = summary.get('hosts_cached', [])
hosts_cached = Host.objects.filter(name__in=host_names).order_by('id').iterator()
# Path where individual fact files were written
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
hosts_to_update = [] hosts_to_update = []
for host in hosts:
for host in hosts_cached: filepath = os.sep.join(map(str, [destination, host.name]))
filepath = os.path.join(fact_cache_dir, host.name) if not os.path.realpath(filepath).startswith(destination):
if not os.path.realpath(filepath).startswith(fact_cache_dir): system_tracking_logger.error('facts for host {} could not be cached'.format(smart_str(host.name)))
logger.error(f'Invalid path for facts file: {filepath}')
continue continue
if os.path.exists(filepath): if os.path.exists(filepath):
# If the file changed since we wrote the last facts file, pre-playbook run... # If the file changed since we wrote the last facts file, pre-playbook run...
modified = os.path.getmtime(filepath) modified = os.path.getmtime(filepath)
if not facts_write_time or modified >= facts_write_time: if (not facts_write_time) or modified > facts_write_time:
try: with codecs.open(filepath, 'r', encoding='utf-8') as f:
with codecs.open(filepath, 'r', encoding='utf-8') as f: try:
ansible_facts = json.load(f) ansible_facts = json.load(f)
except ValueError: except ValueError:
continue continue
if ansible_facts != host.ansible_facts:
host.ansible_facts = ansible_facts host.ansible_facts = ansible_facts
host.ansible_facts_modified = now() host.ansible_facts_modified = now()
hosts_to_update.append(host) hosts_to_update.append(host)
logger.info( system_tracking_logger.info(
f'New fact for inventory {smart_str(host.inventory.name)} host {smart_str(host.name)}', 'New fact for inventory {} host {}'.format(smart_str(host.inventory.name), smart_str(host.name)),
extra=dict( extra=dict(
inventory_id=host.inventory.id, inventory_id=host.inventory.id,
host_name=host.name, host_name=host.name,
@@ -132,21 +126,16 @@ def finish_fact_cache(artifacts_dir, job_id=None, inventory_id=None, log_data=No
), ),
) )
log_data['updated_ct'] += 1 log_data['updated_ct'] += 1
else:
log_data['unmodified_ct'] += 1
else: else:
log_data['unmodified_ct'] += 1 log_data['unmodified_ct'] += 1
else: else:
# if the file goes missing, ansible removed it (likely via clear_facts) # if the file goes missing, ansible removed it (likely via clear_facts)
# if the file goes missing, but the host has not started facts, then we should not clear the facts
host.ansible_facts = {} host.ansible_facts = {}
host.ansible_facts_modified = now() host.ansible_facts_modified = now()
hosts_to_update.append(host) hosts_to_update.append(host)
logger.info(f'Facts cleared for inventory {smart_str(host.inventory.name)} host {smart_str(host.name)}') system_tracking_logger.info('Facts cleared for inventory {} host {}'.format(smart_str(host.inventory.name), smart_str(host.name)))
log_data['cleared_ct'] += 1 log_data['cleared_ct'] += 1
if len(hosts_to_update) > 100:
if len(hosts_to_update) >= 100: update_hosts(hosts_to_update)
bulk_update_sorted_by_id(Host, hosts_to_update, fields=['ansible_facts', 'ansible_facts_modified'])
hosts_to_update = [] hosts_to_update = []
update_hosts(hosts_to_update)
bulk_update_sorted_by_id(Host, hosts_to_update, fields=['ansible_facts', 'ansible_facts_modified'])

View File

@@ -45,46 +45,26 @@ def build_indirect_host_data(job: Job, job_event_queries: dict[str, dict[str, st
facts_missing_logged = False facts_missing_logged = False
unhashable_facts_logged = False unhashable_facts_logged = False
job_event_queries_fqcn = {}
for query_k, query_v in job_event_queries.items():
if len(parts := query_k.split('.')) != 3:
logger.info(f"Skiping malformed query '{query_k}'. Expected to be of the form 'a.b.c'")
continue
if parts[2] != '*':
continue
job_event_queries_fqcn['.'.join(parts[0:2])] = query_v
for event in job.job_events.filter(event_data__isnull=False).iterator(): for event in job.job_events.filter(event_data__isnull=False).iterator():
if 'res' not in event.event_data: if 'res' not in event.event_data:
continue continue
if not (resolved_action := event.event_data.get('resolved_action', None)): if 'resolved_action' not in event.event_data or event.event_data['resolved_action'] not in job_event_queries.keys():
continue continue
if len(resolved_action_parts := resolved_action.split('.')) != 3: resolved_action = event.event_data['resolved_action']
logger.debug(f"Malformed invocation module name '{resolved_action}'. Expected to be of the form 'a.b.c'")
continue
resolved_action_fqcn = '.'.join(resolved_action_parts[0:2]) # We expect a dict with a 'query' key for the resolved_action
if 'query' not in job_event_queries[resolved_action]:
# Match module invocation to collection queries
# First match against fully qualified query names i.e. a.b.c
# Then try and match against wildcard queries i.e. a.b.*
if not (jq_str_for_event := job_event_queries.get(resolved_action, job_event_queries_fqcn.get(resolved_action_fqcn, {})).get('query')):
continue continue
# Recall from cache, or process the jq expression, and loop over the jq results # Recall from cache, or process the jq expression, and loop over the jq results
jq_str_for_event = job_event_queries[resolved_action]['query']
if jq_str_for_event not in compiled_jq_expressions: if jq_str_for_event not in compiled_jq_expressions:
compiled_jq_expressions[resolved_action] = jq.compile(jq_str_for_event) compiled_jq_expressions[resolved_action] = jq.compile(jq_str_for_event)
compiled_jq = compiled_jq_expressions[resolved_action] compiled_jq = compiled_jq_expressions[resolved_action]
for data in compiled_jq.input(event.event_data['res']).all():
try:
data_source = compiled_jq.input(event.event_data['res']).all()
except Exception as e:
logger.warning(f'error for module {resolved_action} and data {event.event_data["res"]}: {e}')
continue
for data in data_source:
# From this jq result (specific to a single Ansible module), get index information about this host record # From this jq result (specific to a single Ansible module), get index information about this host record
if not data.get('canonical_facts'): if not data.get('canonical_facts'):
if not facts_missing_logged: if not facts_missing_logged:

View File

@@ -12,7 +12,6 @@ from awx.main.models.inventory import HostMetric, HostMetricSummaryMonthly
from awx.main.tasks.helpers import is_run_threshold_reached from awx.main.tasks.helpers import is_run_threshold_reached
from awx.conf.license import get_license from awx.conf.license import get_license
from ansible_base.lib.utils.db import advisory_lock from ansible_base.lib.utils.db import advisory_lock
from awx.main.utils.db import bulk_update_sorted_by_id
logger = logging.getLogger('awx.main.tasks.host_metrics') logger = logging.getLogger('awx.main.tasks.host_metrics')
@@ -147,9 +146,8 @@ class HostMetricSummaryMonthlyTask:
month = month + relativedelta(months=1) month = month + relativedelta(months=1)
# Create/Update stats # Create/Update stats
HostMetricSummaryMonthly.objects.bulk_create(self.records_to_create) HostMetricSummaryMonthly.objects.bulk_create(self.records_to_create, batch_size=1000)
HostMetricSummaryMonthly.objects.bulk_update(self.records_to_update, ['license_consumed', 'hosts_added', 'hosts_deleted'], batch_size=1000)
bulk_update_sorted_by_id(HostMetricSummaryMonthly, self.records_to_update, ['license_consumed', 'hosts_added', 'hosts_deleted'])
# Set timestamp of last run # Set timestamp of last run
settings.HOST_METRIC_SUMMARY_TASK_LAST_TS = now() settings.HOST_METRIC_SUMMARY_TASK_LAST_TS = now()

View File

@@ -21,6 +21,7 @@ from django.conf import settings
# Shared code for the AWX platform # Shared code for the AWX platform
from awx_plugins.interfaces._temporary_private_container_api import CONTAINER_ROOT, get_incontainer_path from awx_plugins.interfaces._temporary_private_container_api import CONTAINER_ROOT, get_incontainer_path
# Runner # Runner
import ansible_runner import ansible_runner
@@ -28,6 +29,7 @@ import ansible_runner
import git import git
from gitdb.exc import BadName as BadGitName from gitdb.exc import BadName as BadGitName
# AWX # AWX
from awx.main.dispatch.publish import task from awx.main.dispatch.publish import task
from awx.main.dispatch import get_task_queuename from awx.main.dispatch import get_task_queuename
@@ -63,12 +65,11 @@ from awx.main.tasks.callback import (
RunnerCallbackForProjectUpdate, RunnerCallbackForProjectUpdate,
RunnerCallbackForSystemJob, RunnerCallbackForSystemJob,
) )
from awx.main.tasks.policy import evaluate_policy
from awx.main.tasks.signals import with_signal_handling, signal_callback from awx.main.tasks.signals import with_signal_handling, signal_callback
from awx.main.tasks.receptor import AWXReceptorJob from awx.main.tasks.receptor import AWXReceptorJob
from awx.main.tasks.facts import start_fact_cache, finish_fact_cache from awx.main.tasks.facts import start_fact_cache, finish_fact_cache
from awx.main.tasks.system import update_smart_memberships_for_inventory, update_inventory_computed_fields, events_processed_hook from awx.main.tasks.system import update_smart_memberships_for_inventory, update_inventory_computed_fields, events_processed_hook
from awx.main.exceptions import AwxTaskError, PolicyEvaluationError, PostRunError, ReceptorNodeNotFound from awx.main.exceptions import AwxTaskError, PostRunError, ReceptorNodeNotFound
from awx.main.utils.ansible import read_ansible_config from awx.main.utils.ansible import read_ansible_config
from awx.main.utils.safe_yaml import safe_dump, sanitize_jinja from awx.main.utils.safe_yaml import safe_dump, sanitize_jinja
from awx.main.utils.common import ( from awx.main.utils.common import (
@@ -487,7 +488,6 @@ class BaseTask(object):
self.instance.send_notification_templates("running") self.instance.send_notification_templates("running")
private_data_dir = self.build_private_data_dir(self.instance) private_data_dir = self.build_private_data_dir(self.instance)
self.pre_run_hook(self.instance, private_data_dir) self.pre_run_hook(self.instance, private_data_dir)
evaluate_policy(self.instance)
self.build_project_dir(self.instance, private_data_dir) self.build_project_dir(self.instance, private_data_dir)
self.instance.log_lifecycle("preparing_playbook") self.instance.log_lifecycle("preparing_playbook")
if self.instance.cancel_flag or signal_callback(): if self.instance.cancel_flag or signal_callback():
@@ -522,13 +522,9 @@ class BaseTask(object):
credentials = self.build_credentials_list(self.instance) credentials = self.build_credentials_list(self.instance)
container_root = None
if settings.IS_K8S and isinstance(self.instance, ProjectUpdate):
container_root = private_data_dir
for credential in credentials: for credential in credentials:
if credential: if credential:
credential.credential_type.inject_credential(credential, env, self.safe_cred_env, args, private_data_dir, container_root=container_root) credential.credential_type.inject_credential(credential, env, self.safe_cred_env, args, private_data_dir)
self.runner_callback.safe_env.update(self.safe_cred_env) self.runner_callback.safe_env.update(self.safe_cred_env)
@@ -619,8 +615,6 @@ class BaseTask(object):
elif cancel_flag_value is False: elif cancel_flag_value is False:
self.runner_callback.delay_update(skip_if_already_set=True, job_explanation="The running ansible process received a shutdown signal.") self.runner_callback.delay_update(skip_if_already_set=True, job_explanation="The running ansible process received a shutdown signal.")
status = 'failed' status = 'failed'
except PolicyEvaluationError as exc:
self.runner_callback.delay_update(job_explanation=str(exc), result_traceback=str(exc))
except ReceptorNodeNotFound as exc: except ReceptorNodeNotFound as exc:
self.runner_callback.delay_update(job_explanation=str(exc)) self.runner_callback.delay_update(job_explanation=str(exc))
except Exception: except Exception:
@@ -923,6 +917,7 @@ class RunJob(SourceControlMixin, BaseTask):
env['ANSIBLE_NET_AUTH_PASS'] = network_cred.get_input('authorize_password', default='') env['ANSIBLE_NET_AUTH_PASS'] = network_cred.get_input('authorize_password', default='')
path_vars = [ path_vars = [
('ANSIBLE_COLLECTIONS_PATHS', 'collections_paths', 'requirements_collections', '~/.ansible/collections:/usr/share/ansible/collections'),
('ANSIBLE_ROLES_PATH', 'roles_path', 'requirements_roles', '~/.ansible/roles:/usr/share/ansible/roles:/etc/ansible/roles'), ('ANSIBLE_ROLES_PATH', 'roles_path', 'requirements_roles', '~/.ansible/roles:/usr/share/ansible/roles:/etc/ansible/roles'),
('ANSIBLE_COLLECTIONS_PATH', 'collections_path', 'requirements_collections', '~/.ansible/collections:/usr/share/ansible/collections'), ('ANSIBLE_COLLECTIONS_PATH', 'collections_path', 'requirements_collections', '~/.ansible/collections:/usr/share/ansible/collections'),
] ]
@@ -1093,8 +1088,8 @@ class RunJob(SourceControlMixin, BaseTask):
# where ansible expects to find it # where ansible expects to find it
if self.should_use_fact_cache(): if self.should_use_fact_cache():
job.log_lifecycle("start_job_fact_cache") job.log_lifecycle("start_job_fact_cache")
self.hosts_with_facts_cached = start_fact_cache( self.facts_write_time = start_fact_cache(
job.get_hosts_for_fact_cache(), artifacts_dir=os.path.join(private_data_dir, 'artifacts', str(job.id)), inventory_id=job.inventory_id job.get_hosts_for_fact_cache(), os.path.join(private_data_dir, 'artifacts', str(job.id), 'fact_cache'), inventory_id=job.inventory_id
) )
def build_project_dir(self, job, private_data_dir): def build_project_dir(self, job, private_data_dir):
@@ -1104,7 +1099,7 @@ class RunJob(SourceControlMixin, BaseTask):
super(RunJob, self).post_run_hook(job, status) super(RunJob, self).post_run_hook(job, status)
job.refresh_from_db(fields=['job_env']) job.refresh_from_db(fields=['job_env'])
private_data_dir = job.job_env.get('AWX_PRIVATE_DATA_DIR') private_data_dir = job.job_env.get('AWX_PRIVATE_DATA_DIR')
if not private_data_dir: if (not private_data_dir) or (not hasattr(self, 'facts_write_time')):
# If there's no private data dir, that means we didn't get into the # If there's no private data dir, that means we didn't get into the
# actual `run()` call; this _usually_ means something failed in # actual `run()` call; this _usually_ means something failed in
# the pre_run_hook method # the pre_run_hook method
@@ -1112,7 +1107,9 @@ class RunJob(SourceControlMixin, BaseTask):
if self.should_use_fact_cache() and self.runner_callback.artifacts_processed: if self.should_use_fact_cache() and self.runner_callback.artifacts_processed:
job.log_lifecycle("finish_job_fact_cache") job.log_lifecycle("finish_job_fact_cache")
finish_fact_cache( finish_fact_cache(
artifacts_dir=os.path.join(private_data_dir, 'artifacts', str(job.id)), job.get_hosts_for_fact_cache(),
os.path.join(private_data_dir, 'artifacts', str(job.id), 'fact_cache'),
facts_write_time=self.facts_write_time,
job_id=job.id, job_id=job.id,
inventory_id=job.inventory_id, inventory_id=job.inventory_id,
) )
@@ -1523,7 +1520,7 @@ class RunInventoryUpdate(SourceControlMixin, BaseTask):
raise NotImplementedError('Cannot update file sources through the task system.') raise NotImplementedError('Cannot update file sources through the task system.')
if inventory_update.source == 'scm' and inventory_update.source_project_update: if inventory_update.source == 'scm' and inventory_update.source_project_update:
env_key = 'ANSIBLE_COLLECTIONS_PATH' env_key = 'ANSIBLE_COLLECTIONS_PATHS'
config_setting = 'collections_paths' config_setting = 'collections_paths'
folder = 'requirements_collections' folder = 'requirements_collections'
default = '~/.ansible/collections:/usr/share/ansible/collections' default = '~/.ansible/collections:/usr/share/ansible/collections'
@@ -1541,12 +1538,12 @@ class RunInventoryUpdate(SourceControlMixin, BaseTask):
paths = [config_values[config_setting]] + paths paths = [config_values[config_setting]] + paths
paths = [os.path.join(CONTAINER_ROOT, folder)] + paths paths = [os.path.join(CONTAINER_ROOT, folder)] + paths
env[env_key] = os.pathsep.join(paths) env[env_key] = os.pathsep.join(paths)
if 'ANSIBLE_COLLECTIONS_PATH' in env: if 'ANSIBLE_COLLECTIONS_PATHS' in env:
paths = env['ANSIBLE_COLLECTIONS_PATH'].split(':') paths = env['ANSIBLE_COLLECTIONS_PATHS'].split(':')
else: else:
paths = ['~/.ansible/collections', '/usr/share/ansible/collections'] paths = ['~/.ansible/collections', '/usr/share/ansible/collections']
paths.append('/usr/share/automation-controller/collections') paths.append('/usr/share/automation-controller/collections')
env['ANSIBLE_COLLECTIONS_PATH'] = os.pathsep.join(paths) env['ANSIBLE_COLLECTIONS_PATHS'] = os.pathsep.join(paths)
return env return env
@@ -1578,7 +1575,7 @@ class RunInventoryUpdate(SourceControlMixin, BaseTask):
# Include any facts from input inventories so they can be used in filters # Include any facts from input inventories so they can be used in filters
start_fact_cache( start_fact_cache(
input_inventory.hosts.only(*HOST_FACTS_FIELDS), input_inventory.hosts.only(*HOST_FACTS_FIELDS),
artifacts_dir=os.path.join(private_data_dir, 'artifacts', str(inventory_update.id)), os.path.join(private_data_dir, 'artifacts', str(inventory_update.id), 'fact_cache'),
inventory_id=input_inventory.id, inventory_id=input_inventory.id,
) )

View File

@@ -1,462 +0,0 @@
import json
import tempfile
import contextlib
from pprint import pformat
from typing import Optional, Union
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from flags.state import flag_enabled
from opa_client import OpaClient
from opa_client.base import BaseClient
from requests import HTTPError
from rest_framework import serializers
from rest_framework import fields
from awx.main import models
from awx.main.exceptions import PolicyEvaluationError
# Monkey patching opa_client.base.BaseClient to fix retries and timeout settings
_original_opa_base_client_init = BaseClient.__init__
def _opa_base_client_init_fix(
self,
host: str = "localhost",
port: int = 8181,
version: str = "v1",
ssl: bool = False,
cert: Optional[Union[str, tuple]] = None,
headers: Optional[dict] = None,
retries: int = 2,
timeout: float = 1.5,
):
_original_opa_base_client_init(self, host, port, version, ssl, cert, headers)
self.retries = retries
self.timeout = timeout
BaseClient.__init__ = _opa_base_client_init_fix
class _TeamSerializer(serializers.ModelSerializer):
class Meta:
model = models.Team
fields = ('id', 'name')
class _UserSerializer(serializers.ModelSerializer):
teams = serializers.SerializerMethodField()
class Meta:
model = models.User
fields = ('id', 'username', 'is_superuser', 'teams')
def get_teams(self, user: models.User):
teams = models.Team.access_qs(user, 'member')
return _TeamSerializer(many=True).to_representation(teams)
class _ExecutionEnvironmentSerializer(serializers.ModelSerializer):
class Meta:
model = models.ExecutionEnvironment
fields = (
'id',
'name',
'image',
'pull',
)
class _InstanceGroupSerializer(serializers.ModelSerializer):
class Meta:
model = models.InstanceGroup
fields = (
'id',
'name',
'capacity',
'jobs_running',
'jobs_total',
'max_concurrent_jobs',
'max_forks',
)
class _InventorySourceSerializer(serializers.ModelSerializer):
class Meta:
model = models.InventorySource
fields = ('id', 'name', 'source', 'status')
class _InventorySerializer(serializers.ModelSerializer):
inventory_sources = _InventorySourceSerializer(many=True)
class Meta:
model = models.Inventory
fields = (
'id',
'name',
'description',
'kind',
'total_hosts',
'total_groups',
'has_inventory_sources',
'total_inventory_sources',
'has_active_failures',
'hosts_with_active_failures',
'inventory_sources',
)
class _JobTemplateSerializer(serializers.ModelSerializer):
class Meta:
model = models.JobTemplate
fields = (
'id',
'name',
'job_type',
)
class _WorkflowJobTemplateSerializer(serializers.ModelSerializer):
class Meta:
model = models.WorkflowJobTemplate
fields = (
'id',
'name',
'job_type',
)
class _WorkflowJobSerializer(serializers.ModelSerializer):
class Meta:
model = models.WorkflowJob
fields = (
'id',
'name',
)
class _OrganizationSerializer(serializers.ModelSerializer):
class Meta:
model = models.Organization
fields = (
'id',
'name',
)
class _ProjectSerializer(serializers.ModelSerializer):
class Meta:
model = models.Project
fields = (
'id',
'name',
'status',
'scm_type',
'scm_url',
'scm_branch',
'scm_refspec',
'scm_clean',
'scm_track_submodules',
'scm_delete_on_update',
)
class _CredentialSerializer(serializers.ModelSerializer):
organization = _OrganizationSerializer()
class Meta:
model = models.Credential
fields = (
'id',
'name',
'description',
'organization',
'credential_type',
'managed',
'kind',
'cloud',
'kubernetes',
)
class _LabelSerializer(serializers.ModelSerializer):
organization = _OrganizationSerializer()
class Meta:
model = models.Label
fields = ('id', 'name', 'organization')
class JobSerializer(serializers.ModelSerializer):
created_by = _UserSerializer()
credentials = _CredentialSerializer(many=True)
execution_environment = _ExecutionEnvironmentSerializer()
instance_group = _InstanceGroupSerializer()
inventory = _InventorySerializer()
job_template = _JobTemplateSerializer()
labels = _LabelSerializer(many=True)
organization = _OrganizationSerializer()
project = _ProjectSerializer()
extra_vars = fields.SerializerMethodField()
hosts_count = fields.SerializerMethodField()
workflow_job = fields.SerializerMethodField()
workflow_job_template = fields.SerializerMethodField()
class Meta:
model = models.Job
fields = (
'id',
'name',
'created',
'created_by',
'credentials',
'execution_environment',
'extra_vars',
'forks',
'hosts_count',
'instance_group',
'inventory',
'job_template',
'job_type',
'job_type_name',
'labels',
'launch_type',
'limit',
'launched_by',
'organization',
'playbook',
'project',
'scm_branch',
'scm_revision',
'workflow_job',
'workflow_job_template',
)
def get_extra_vars(self, obj: models.Job):
return json.loads(obj.display_extra_vars())
def get_hosts_count(self, obj: models.Job):
return obj.hosts.count()
def get_workflow_job(self, obj: models.Job):
workflow_job: models.WorkflowJob = obj.get_workflow_job()
if workflow_job is None:
return None
return _WorkflowJobSerializer().to_representation(workflow_job)
def get_workflow_job_template(self, obj: models.Job):
workflow_job: models.WorkflowJob = obj.get_workflow_job()
if workflow_job is None:
return None
workflow_job_template: models.WorkflowJobTemplate = workflow_job.workflow_job_template
if workflow_job_template is None:
return None
return _WorkflowJobTemplateSerializer().to_representation(workflow_job_template)
class OPAResultSerializer(serializers.Serializer):
allowed = fields.BooleanField(required=True)
violations = fields.ListField(child=fields.CharField())
class OPA_AUTH_TYPES:
NONE = 'None'
TOKEN = 'Token'
CERTIFICATE = 'Certificate'
@contextlib.contextmanager
def opa_cert_file():
"""
Context manager that creates temporary certificate files for OPA authentication.
For mTLS (mutual TLS), we need:
- Client certificate and key for client authentication
- CA certificate (optional) for server verification
Returns:
tuple: (client_cert_path, verify_path)
- client_cert_path: Path to client cert file or None if not using client cert
- verify_path: Path to CA cert file, True to use system CA store, or False for no verification
"""
client_cert_temp = None
ca_temp = None
try:
# Case 1: Full mTLS with client cert and optional CA cert
if settings.OPA_AUTH_TYPE == OPA_AUTH_TYPES.CERTIFICATE:
# Create client certificate file (required for mTLS)
client_cert_temp = tempfile.NamedTemporaryFile(delete=True, mode='w', suffix=".pem")
client_cert_temp.write(settings.OPA_AUTH_CLIENT_CERT)
client_cert_temp.write("\n")
client_cert_temp.write(settings.OPA_AUTH_CLIENT_KEY)
client_cert_temp.write("\n")
client_cert_temp.flush()
# If CA cert is provided, use it for server verification
# Otherwise, use system CA store (True)
if settings.OPA_AUTH_CA_CERT:
ca_temp = tempfile.NamedTemporaryFile(delete=True, mode='w', suffix=".pem")
ca_temp.write(settings.OPA_AUTH_CA_CERT)
ca_temp.write("\n")
ca_temp.flush()
verify_path = ca_temp.name
else:
verify_path = True # Use system CA store
yield (client_cert_temp.name, verify_path)
# Case 2: TLS with only server verification (no client cert)
elif settings.OPA_SSL:
# If CA cert is provided, use it for server verification
# Otherwise, use system CA store (True)
if settings.OPA_AUTH_CA_CERT:
ca_temp = tempfile.NamedTemporaryFile(delete=True, mode='w', suffix=".pem")
ca_temp.write(settings.OPA_AUTH_CA_CERT)
ca_temp.write("\n")
ca_temp.flush()
verify_path = ca_temp.name
else:
verify_path = True # Use system CA store
yield (None, verify_path)
# Case 3: No TLS
else:
yield (None, False)
finally:
# Clean up temporary files
if client_cert_temp:
client_cert_temp.close()
if ca_temp:
ca_temp.close()
@contextlib.contextmanager
def opa_client(headers=None):
with opa_cert_file() as cert_files:
cert, verify = cert_files
with OpaClient(
host=settings.OPA_HOST,
port=settings.OPA_PORT,
headers=headers,
ssl=settings.OPA_SSL,
cert=cert,
timeout=settings.OPA_REQUEST_TIMEOUT,
retries=settings.OPA_REQUEST_RETRIES,
) as client:
# Workaround for https://github.com/Turall/OPA-python-client/issues/32
# by directly setting cert and verify on requests.session
client._session.cert = cert
client._session.verify = verify
yield client
def evaluate_policy(instance):
# Policy evaluation for Policy as Code feature
if not flag_enabled("FEATURE_POLICY_AS_CODE_ENABLED"):
return
if not settings.OPA_HOST:
return
if not isinstance(instance, models.Job):
return
instance.log_lifecycle("evaluate_policy")
input_data = JobSerializer(instance=instance).data
headers = settings.OPA_AUTH_CUSTOM_HEADERS
if settings.OPA_AUTH_TYPE == OPA_AUTH_TYPES.TOKEN:
headers.update({'Authorization': 'Bearer {}'.format(settings.OPA_AUTH_TOKEN)})
if settings.OPA_AUTH_TYPE == OPA_AUTH_TYPES.CERTIFICATE and not settings.OPA_SSL:
raise PolicyEvaluationError(_('OPA_AUTH_TYPE=Certificate requires OPA_SSL to be enabled.'))
cert_settings_missing = []
if settings.OPA_AUTH_TYPE == OPA_AUTH_TYPES.CERTIFICATE:
if not settings.OPA_AUTH_CLIENT_CERT:
cert_settings_missing += ['OPA_AUTH_CLIENT_CERT']
if not settings.OPA_AUTH_CLIENT_KEY:
cert_settings_missing += ['OPA_AUTH_CLIENT_KEY']
if not settings.OPA_AUTH_CA_CERT:
cert_settings_missing += ['OPA_AUTH_CA_CERT']
if cert_settings_missing:
raise PolicyEvaluationError(_('Following certificate settings are missing for OPA_AUTH_TYPE=Certificate: {}').format(cert_settings_missing))
query_paths = [
('Organization', instance.organization.opa_query_path),
('Inventory', instance.inventory.opa_query_path),
('Job template', instance.job_template.opa_query_path),
]
violations = dict()
errors = dict()
try:
with opa_client(headers=headers) as client:
for path_type, query_path in query_paths:
response = dict()
try:
if not query_path:
continue
response = client.query_rule(input_data=input_data, package_path=query_path)
except HTTPError as e:
message = _('Call to OPA failed. Exception: {}').format(e)
try:
error_data = e.response.json()
except ValueError:
errors[path_type] = message
continue
error_code = error_data.get("code")
error_message = error_data.get("message")
if error_code or error_message:
message = _('Call to OPA failed. Code: {}, Message: {}').format(error_code, error_message)
errors[path_type] = message
continue
except Exception as e:
errors[path_type] = _('Call to OPA failed. Exception: {}').format(e)
continue
result = response.get('result')
if result is None:
errors[path_type] = _('Call to OPA did not return a "result" property. The path refers to an undefined document.')
continue
result_serializer = OPAResultSerializer(data=result)
if not result_serializer.is_valid():
errors[path_type] = _('OPA policy returned invalid result.')
continue
result_data = result_serializer.validated_data
if not result_data.get("allowed") and (result_violations := result_data.get("violations")):
violations[path_type] = result_violations
format_results = dict()
if any(errors[e] for e in errors):
format_results["Errors"] = errors
if any(violations[v] for v in violations):
format_results["Violations"] = violations
if violations or errors:
raise PolicyEvaluationError(pformat(format_results, width=80))
except Exception as e:
raise PolicyEvaluationError(_('This job cannot be executed due to a policy violation or error. See the following details:\n{}').format(e))

View File

@@ -1,7 +0,0 @@
---
- hosts: all
gather_facts: false
connection: local
tasks:
- meta: clear_facts

View File

@@ -1,17 +0,0 @@
---
- hosts: all
vars:
extra_value: ""
gather_facts: false
connection: local
tasks:
- name: set a custom fact
set_fact:
foo: "bar{{ extra_value }}"
bar:
a:
b:
- "c"
- "d"
cacheable: true

View File

@@ -1,9 +0,0 @@
---
- hosts: all
gather_facts: false
connection: local
vars:
msg: 'hello'
tasks:
- debug: var=msg

View File

@@ -1,3 +0,0 @@
[all:vars]
a=value_a
b=value_b

View File

@@ -1,17 +0,0 @@
import time
import logging
from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.publish import task
logger = logging.getLogger(__name__)
@task(queue=get_task_queuename)
def sleep_task(seconds=10, log=False):
if log:
logger.info('starting sleep_task')
time.sleep(seconds)
if log:
logger.info('finished sleep_task')

View File

@@ -87,8 +87,8 @@ def mock_analytic_post():
{ {
'REDHAT_USERNAME': 'redhat_user', 'REDHAT_USERNAME': 'redhat_user',
'REDHAT_PASSWORD': 'redhat_pass', # NOSONAR 'REDHAT_PASSWORD': 'redhat_pass', # NOSONAR
'SUBSCRIPTIONS_CLIENT_ID': '', 'SUBSCRIPTIONS_USERNAME': '',
'SUBSCRIPTIONS_CLIENT_SECRET': '', 'SUBSCRIPTIONS_PASSWORD': '',
}, },
True, True,
('redhat_user', 'redhat_pass'), ('redhat_user', 'redhat_pass'),
@@ -98,8 +98,8 @@ def mock_analytic_post():
{ {
'REDHAT_USERNAME': None, 'REDHAT_USERNAME': None,
'REDHAT_PASSWORD': None, 'REDHAT_PASSWORD': None,
'SUBSCRIPTIONS_CLIENT_ID': 'subs_user', 'SUBSCRIPTIONS_USERNAME': 'subs_user',
'SUBSCRIPTIONS_CLIENT_SECRET': 'subs_pass', # NOSONAR 'SUBSCRIPTIONS_PASSWORD': 'subs_pass', # NOSONAR
}, },
True, True,
('subs_user', 'subs_pass'), ('subs_user', 'subs_pass'),
@@ -109,8 +109,8 @@ def mock_analytic_post():
{ {
'REDHAT_USERNAME': '', 'REDHAT_USERNAME': '',
'REDHAT_PASSWORD': '', 'REDHAT_PASSWORD': '',
'SUBSCRIPTIONS_CLIENT_ID': 'subs_user', 'SUBSCRIPTIONS_USERNAME': 'subs_user',
'SUBSCRIPTIONS_CLIENT_SECRET': 'subs_pass', # NOSONAR 'SUBSCRIPTIONS_PASSWORD': 'subs_pass', # NOSONAR
}, },
True, True,
('subs_user', 'subs_pass'), ('subs_user', 'subs_pass'),
@@ -120,8 +120,8 @@ def mock_analytic_post():
{ {
'REDHAT_USERNAME': '', 'REDHAT_USERNAME': '',
'REDHAT_PASSWORD': '', 'REDHAT_PASSWORD': '',
'SUBSCRIPTIONS_CLIENT_ID': '', 'SUBSCRIPTIONS_USERNAME': '',
'SUBSCRIPTIONS_CLIENT_SECRET': '', 'SUBSCRIPTIONS_PASSWORD': '',
}, },
False, False,
None, # No request should be made None, # No request should be made
@@ -131,8 +131,8 @@ def mock_analytic_post():
{ {
'REDHAT_USERNAME': '', 'REDHAT_USERNAME': '',
'REDHAT_PASSWORD': 'redhat_pass', # NOSONAR 'REDHAT_PASSWORD': 'redhat_pass', # NOSONAR
'SUBSCRIPTIONS_CLIENT_ID': 'subs_user', 'SUBSCRIPTIONS_USERNAME': 'subs_user',
'SUBSCRIPTIONS_CLIENT_SECRET': '', 'SUBSCRIPTIONS_PASSWORD': '',
}, },
False, False,
None, # Invalid, no request should be made None, # Invalid, no request should be made

View File

@@ -97,8 +97,8 @@ class TestAnalyticsGenericView:
'INSIGHTS_TRACKING_STATE': True, 'INSIGHTS_TRACKING_STATE': True,
'REDHAT_USERNAME': 'redhat_user', 'REDHAT_USERNAME': 'redhat_user',
'REDHAT_PASSWORD': 'redhat_pass', # NOSONAR 'REDHAT_PASSWORD': 'redhat_pass', # NOSONAR
'SUBSCRIPTIONS_CLIENT_ID': '', 'SUBSCRIPTIONS_USERNAME': '',
'SUBSCRIPTIONS_CLIENT_SECRET': '', 'SUBSCRIPTIONS_PASSWORD': '',
}, },
('redhat_user', 'redhat_pass'), ('redhat_user', 'redhat_pass'),
None, None,
@@ -109,8 +109,8 @@ class TestAnalyticsGenericView:
'INSIGHTS_TRACKING_STATE': True, 'INSIGHTS_TRACKING_STATE': True,
'REDHAT_USERNAME': '', 'REDHAT_USERNAME': '',
'REDHAT_PASSWORD': '', 'REDHAT_PASSWORD': '',
'SUBSCRIPTIONS_CLIENT_ID': 'subs_user', 'SUBSCRIPTIONS_USERNAME': 'subs_user',
'SUBSCRIPTIONS_CLIENT_SECRET': 'subs_pass', # NOSONAR 'SUBSCRIPTIONS_PASSWORD': 'subs_pass', # NOSONAR
}, },
('subs_user', 'subs_pass'), ('subs_user', 'subs_pass'),
None, None,
@@ -121,8 +121,8 @@ class TestAnalyticsGenericView:
'INSIGHTS_TRACKING_STATE': True, 'INSIGHTS_TRACKING_STATE': True,
'REDHAT_USERNAME': '', 'REDHAT_USERNAME': '',
'REDHAT_PASSWORD': '', 'REDHAT_PASSWORD': '',
'SUBSCRIPTIONS_CLIENT_ID': '', 'SUBSCRIPTIONS_USERNAME': '',
'SUBSCRIPTIONS_CLIENT_SECRET': '', 'SUBSCRIPTIONS_PASSWORD': '',
}, },
None, None,
ERROR_MISSING_USER, ERROR_MISSING_USER,
@@ -133,8 +133,8 @@ class TestAnalyticsGenericView:
'INSIGHTS_TRACKING_STATE': True, 'INSIGHTS_TRACKING_STATE': True,
'REDHAT_USERNAME': 'redhat_user', 'REDHAT_USERNAME': 'redhat_user',
'REDHAT_PASSWORD': 'redhat_pass', # NOSONAR 'REDHAT_PASSWORD': 'redhat_pass', # NOSONAR
'SUBSCRIPTIONS_CLIENT_ID': 'subs_user', 'SUBSCRIPTIONS_USERNAME': 'subs_user',
'SUBSCRIPTIONS_CLIENT_SECRET': 'subs_pass', # NOSONAR 'SUBSCRIPTIONS_PASSWORD': 'subs_pass', # NOSONAR
}, },
('redhat_user', 'redhat_pass'), ('redhat_user', 'redhat_pass'),
None, None,
@@ -145,8 +145,8 @@ class TestAnalyticsGenericView:
'INSIGHTS_TRACKING_STATE': True, 'INSIGHTS_TRACKING_STATE': True,
'REDHAT_USERNAME': '', 'REDHAT_USERNAME': '',
'REDHAT_PASSWORD': '', 'REDHAT_PASSWORD': '',
'SUBSCRIPTIONS_CLIENT_ID': 'subs_user', # NOSONAR 'SUBSCRIPTIONS_USERNAME': 'subs_user', # NOSONAR
'SUBSCRIPTIONS_CLIENT_SECRET': '', 'SUBSCRIPTIONS_PASSWORD': '',
}, },
None, None,
ERROR_MISSING_PASSWORD, ERROR_MISSING_PASSWORD,
@@ -155,36 +155,26 @@ class TestAnalyticsGenericView:
) )
@pytest.mark.django_db @pytest.mark.django_db
def test__send_to_analytics_credentials(self, settings_map, expected_auth, expected_error_keyword): def test__send_to_analytics_credentials(self, settings_map, expected_auth, expected_error_keyword):
"""
Test _send_to_analytics with various combinations of credentials.
"""
with override_settings(**settings_map): with override_settings(**settings_map):
request = RequestFactory().post('/some/path') request = RequestFactory().post('/some/path')
view = AnalyticsGenericView() view = AnalyticsGenericView()
if expected_auth: if expected_auth:
with mock.patch('awx.api.views.analytics.OIDCClient') as mock_oidc_client: with mock.patch('requests.request') as mock_request:
# Configure the mock OIDCClient instance and its make_request method mock_request.return_value = mock.Mock(status_code=200)
mock_client_instance = mock.Mock()
mock_oidc_client.return_value = mock_client_instance
mock_client_instance.make_request.return_value = mock.Mock(status_code=200)
analytic_url = view._get_analytics_url(request.path) analytic_url = view._get_analytics_url(request.path)
response = view._send_to_analytics(request, 'POST') response = view._send_to_analytics(request, 'POST')
# Assertions # Assertions
# Assert OIDCClient instantiation mock_request.assert_called_once_with(
expected_client_id, expected_client_secret = expected_auth
mock_oidc_client.assert_called_once_with(expected_client_id, expected_client_secret)
# Assert make_request call
mock_client_instance.make_request.assert_called_once_with(
'POST', 'POST',
analytic_url, analytic_url,
headers=mock.ANY, auth=expected_auth,
verify=mock.ANY, verify=mock.ANY,
params=mock.ANY, headers=mock.ANY,
json=mock.ANY, json=mock.ANY,
params=mock.ANY,
timeout=mock.ANY, timeout=mock.ANY,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -196,64 +186,3 @@ class TestAnalyticsGenericView:
# mock_error_response.assert_called_once_with(expected_error_keyword, remote=False) # mock_error_response.assert_called_once_with(expected_error_keyword, remote=False)
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
assert response.data['error']['keyword'] == expected_error_keyword assert response.data['error']['keyword'] == expected_error_keyword
@pytest.mark.django_db
@pytest.mark.parametrize(
"settings_map, expected_auth",
[
# Test case 1: Username and password should be used for basic auth
(
{
'INSIGHTS_TRACKING_STATE': True,
'REDHAT_USERNAME': 'redhat_user',
'REDHAT_PASSWORD': 'redhat_pass', # NOSONAR
'SUBSCRIPTIONS_CLIENT_ID': '',
'SUBSCRIPTIONS_CLIENT_SECRET': '',
},
('redhat_user', 'redhat_pass'),
),
# Test case 2: Client ID and secret should be used for basic auth
(
{
'INSIGHTS_TRACKING_STATE': True,
'REDHAT_USERNAME': '',
'REDHAT_PASSWORD': '',
'SUBSCRIPTIONS_CLIENT_ID': 'subs_user',
'SUBSCRIPTIONS_CLIENT_SECRET': 'subs_pass', # NOSONAR
},
None,
),
],
)
def test__send_to_analytics_fallback_to_basic_auth(self, settings_map, expected_auth):
"""
Test _send_to_analytics with basic auth fallback.
"""
with override_settings(**settings_map):
request = RequestFactory().post('/some/path')
view = AnalyticsGenericView()
with mock.patch('awx.api.views.analytics.OIDCClient') as mock_oidc_client, mock.patch(
'awx.api.views.analytics.AnalyticsGenericView._base_auth_request'
) as mock_base_auth_request:
# Configure the mock OIDCClient instance and its make_request method
mock_client_instance = mock.Mock()
mock_oidc_client.return_value = mock_client_instance
mock_client_instance.make_request.side_effect = requests.RequestException("Incorrect credentials")
analytic_url = view._get_analytics_url(request.path)
view._send_to_analytics(request, 'POST')
if expected_auth:
# assert mock_base_auth_request called with expected_auth
mock_base_auth_request.assert_called_once_with(
request,
'POST',
analytic_url,
expected_auth[0],
expected_auth[1],
mock.ANY,
)
else:
# assert mock_base_auth_request not called
mock_base_auth_request.assert_not_called()

View File

@@ -8,7 +8,6 @@ from django.core.exceptions import ValidationError
from awx.api.versioning import reverse from awx.api.versioning import reverse
from awx.main.models import InventorySource, Inventory, ActivityStream from awx.main.models import InventorySource, Inventory, ActivityStream
from awx.main.utils.inventory_vars import update_group_variables
@pytest.fixture @pytest.fixture
@@ -691,241 +690,3 @@ class TestConstructedInventory:
assert inv_r.data['url'] != const_r.data['url'] assert inv_r.data['url'] != const_r.data['url']
assert inv_r.data['related']['constructed_url'] == url_const assert inv_r.data['related']['constructed_url'] == url_const
assert const_r.data['related']['constructed_url'] == url_const assert const_r.data['related']['constructed_url'] == url_const
@pytest.mark.django_db
class TestInventoryAllVariables:
@staticmethod
def simulate_update_from_source(inv_src, variables_dict, overwrite_vars=True):
"""
Update `inventory` with variables `variables_dict` from source
`inv_src`.
"""
# Perform an update from source the same way it is done in
# `inventory_import.Command._update_inventory`.
new_vars = update_group_variables(
group_id=None, # `None` denotes the 'all' group (which doesn't have a pk).
newvars=variables_dict,
dbvars=inv_src.inventory.variables_dict,
invsrc_id=inv_src.id,
inventory_id=inv_src.inventory.id,
overwrite_vars=overwrite_vars,
)
inv_src.inventory.variables = json.dumps(new_vars)
inv_src.inventory.save(update_fields=["variables"])
return new_vars
def update_and_verify(self, inv_src, new_vars, expect=None, overwrite_vars=True, teststep=None):
"""
Helper: Update from source and verify the new inventory variables.
:param inv_src: An inventory source object with its inventory property
set to the inventory fixture of the called.
:param dict new_vars: The variables of the inventory source `inv_src`.
:param dict expect: (optional) The expected variables state of the
inventory after the update. If not set or None, expect `new_vars`.
:param bool overwrite_vars: The status of the inventory source option
'overwrite variables'. Default is `True`.
:raise AssertionError: If the inventory does not contain the expected
variables after the update.
"""
self.simulate_update_from_source(inv_src, new_vars, overwrite_vars=overwrite_vars)
if teststep is not None:
assert inv_src.inventory.variables_dict == (expect if expect is not None else new_vars), f"Test step {teststep}"
else:
assert inv_src.inventory.variables_dict == (expect if expect is not None else new_vars)
def test_set_variables_through_inventory_details_update(self, inventory, patch, admin_user):
"""
Set an inventory variable by changing the inventory details, simulating
a user edit.
"""
# a: x
patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'a: x'}, user=admin_user, expect=200)
inventory.refresh_from_db()
assert inventory.variables_dict == {"a": "x"}
def test_variables_set_by_user_persist_update_from_src(self, inventory, inventory_source, patch, admin_user):
"""
Verify the special behavior that a variable which originates from a user
edit (instead of a source update), is not removed from the inventory
when a source update with overwrite_vars=True does not contain that
variable. This behavior is considered special because a variable which
originates from a source would actually be deleted.
In addition, verify that an existing variable which was set by a user
edit can be overwritten by a source update.
"""
# Set two variables via user edit.
patch(
url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}),
data={'variables': '{"a": "a_from_user", "b": "b_from_user"}'},
user=admin_user,
expect=200,
)
inventory.refresh_from_db()
assert inventory.variables_dict == {'a': 'a_from_user', 'b': 'b_from_user'}
# Update from a source which contains only one of the two variables from
# the previous update.
self.simulate_update_from_source(inventory_source, {'a': 'a_from_source'})
# Verify inventory variables.
assert inventory.variables_dict == {'a': 'a_from_source', 'b': 'b_from_user'}
def test_variables_set_through_src_get_removed_on_update_from_same_src(self, inventory, inventory_source, patch, admin_user):
"""
Verify that a variable which originates from a source update, is removed
from the inventory when a source update with overwrite_vars=True does
not contain that variable.
In addition, verify that an existing variable which was set by a user
edit can be overwritten by a source update.
"""
# Set two variables via update from source.
self.simulate_update_from_source(inventory_source, {'a': 'a_from_source', 'b': 'b_from_source'})
# Verify inventory variables.
assert inventory.variables_dict == {'a': 'a_from_source', 'b': 'b_from_source'}
# Update from the same source which now contains only one of the two
# variables from the previous update.
self.simulate_update_from_source(inventory_source, {'b': 'b_from_source'})
# Verify the variable has been deleted from the inventory.
assert inventory.variables_dict == {'b': 'b_from_source'}
def test_overwrite_variables_through_inventory_details_update(self, inventory, patch, admin_user):
"""
Set and update the inventory variables multiple times by changing the
inventory details via api, simulating user edits.
Any variables update by means of an inventory details update shall
overwright all existing inventory variables.
"""
# a: x
patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'a: x'}, user=admin_user, expect=200)
inventory.refresh_from_db()
assert inventory.variables_dict == {"a": "x"}
# a: x2
patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'a: x2'}, user=admin_user, expect=200)
inventory.refresh_from_db()
assert inventory.variables_dict == {"a": "x2"}
# b: y
patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'b: y'}, user=admin_user, expect=200)
inventory.refresh_from_db()
assert inventory.variables_dict == {"b": "y"}
def test_inventory_group_variables_internal_data(self, inventory, patch, admin_user):
"""
Basic verification of how variable updates are stored internally.
.. Warning::
This test verifies a specific implementation of the inventory
variables update business logic. It may deliver false negatives if
the implementation changes.
"""
# x: a
patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'a: x'}, user=admin_user, expect=200)
igv = inventory.inventory_group_variables.first()
assert igv.variables == {'a': [[-1, 'x']]}
# b: y
patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'b: y'}, user=admin_user, expect=200)
igv = inventory.inventory_group_variables.first()
assert igv.variables == {'b': [[-1, 'y']]}
def test_update_then_user_change(self, inventory, patch, admin_user, inventory_source):
"""
1. Update inventory vars by means of an inventory source update.
2. Update inventory vars by editing the inventory details (aka a 'user
update'), thereby changing variables values and deleting variables
from the inventory.
.. Warning::
This test partly relies on a specific implementation of the
inventory variables update business logic. It may deliver false
negatives if the implementation changes.
"""
assert inventory_source.inventory_id == inventory.pk # sanity
# ---- Test step 1: Set variables by updating from an inventory source.
self.simulate_update_from_source(inventory_source, {'foo': 'foo_from_source', 'bar': 'bar_from_source'})
# Verify inventory variables.
assert inventory.variables_dict == {'foo': 'foo_from_source', 'bar': 'bar_from_source'}
# Verify internal storage of variables data. Note that this is
# implementation specific
assert inventory.inventory_group_variables.count() == 1
igv = inventory.inventory_group_variables.first()
assert igv.variables == {'foo': [[inventory_source.id, 'foo_from_source']], 'bar': [[inventory_source.id, 'bar_from_source']]}
# ---- Test step 2: Change the variables by editing the inventory details.
patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'foo: foo_from_user'}, user=admin_user, expect=200)
inventory.refresh_from_db()
# Verify that variable `foo` contains the new value, and that variable
# `bar` has been deleted from the inventory.
assert inventory.variables_dict == {"foo": "foo_from_user"}
# Verify internal storage of variables data. Note that this is
# implementation specific
inventory.inventory_group_variables.count() == 1
igv = inventory.inventory_group_variables.first()
assert igv.variables == {'foo': [[-1, 'foo_from_user']]}
def test_monotonic_deletions(self, inventory, patch, admin_user):
"""
Verify the variables history logic for monotonic deletions.
Monotonic in this context means that the variables are deleted in the
reverse order of their creation.
1. Set inventory variable x: 0, expect INV={x: 0}
(The following steps use overwrite_variables=False)
2. Update from source A={x: 1}, expect INV={x: 1}
3. Update from source B={x: 2}, expect INV={x: 2}
4. Update from source B={}, expect INV={x: 1}
5. Update from source A={}, expect INV={x: 0}
"""
inv_src_a = InventorySource.objects.create(name="inv-src-A", inventory=inventory, source="ec2")
inv_src_b = InventorySource.objects.create(name="inv-src-B", inventory=inventory, source="ec2")
# Test step 1:
patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'x: 0'}, user=admin_user, expect=200)
inventory.refresh_from_db()
assert inventory.variables_dict == {"x": 0}
# Test step 2: Source A overwrites value of var x
self.update_and_verify(inv_src_a, {"x": 1}, teststep=2)
# Test step 3: Source A overwrites value of var x
self.update_and_verify(inv_src_b, {"x": 2}, teststep=3)
# Test step 4: Value of var x from source A reappears
self.update_and_verify(inv_src_b, {}, expect={"x": 1}, teststep=4)
# Test step 5: Value of var x from initial user edit reappears
self.update_and_verify(inv_src_a, {}, expect={"x": 0}, teststep=5)
def test_interleaved_deletions(self, inventory, patch, admin_user, inventory_source):
"""
Verify the variables history logic for interleaved deletions.
Interleaved in this context means that the variables are deleted in a
different order than the sequence of their creation.
1. Set inventory variable x: 0, expect INV={x: 0}
2. Update from source A={x: 1}, expect INV={x: 1}
3. Update from source B={x: 2}, expect INV={x: 2}
4. Update from source C={x: 3}, expect INV={x: 3}
5. Update from source B={}, expect INV={x: 3}
6. Update from source C={}, expect INV={x: 1}
"""
inv_src_a = InventorySource.objects.create(name="inv-src-A", inventory=inventory, source="ec2")
inv_src_b = InventorySource.objects.create(name="inv-src-B", inventory=inventory, source="ec2")
inv_src_c = InventorySource.objects.create(name="inv-src-C", inventory=inventory, source="ec2")
# Test step 1. Set inventory variable x: 0
patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'x: 0'}, user=admin_user, expect=200)
inventory.refresh_from_db()
assert inventory.variables_dict == {"x": 0}
# Test step 2: Source A overwrites value of var x
self.update_and_verify(inv_src_a, {"x": 1}, teststep=2)
# Test step 3: Source B overwrites value of var x
self.update_and_verify(inv_src_b, {"x": 2}, teststep=3)
# Test step 4: Source C overwrites value of var x
self.update_and_verify(inv_src_c, {"x": 3}, teststep=4)
# Test step 5: Value of var x from source C remains unchanged
self.update_and_verify(inv_src_b, {}, expect={"x": 3}, teststep=5)
# Test step 6: Value of var x from source A reappears, because the
# latest update from source B did not contain var x.
self.update_and_verify(inv_src_c, {}, expect={"x": 1}, teststep=6)

View File

@@ -210,39 +210,6 @@ def test_disallowed_http_update_methods(put, patch, post, inventory, project, ad
patch(url=reverse('api:job_detail', kwargs={'pk': job.pk}), data={}, user=admin_user, expect=405) patch(url=reverse('api:job_detail', kwargs={'pk': job.pk}), data={}, user=admin_user, expect=405)
@pytest.mark.django_db
@pytest.mark.parametrize(
"job_type",
[
'run',
'check',
],
)
def test_job_relaunch_with_job_type(post, inventory, project, machine_credential, admin_user, job_type):
# Create a job template
jt = JobTemplate.objects.create(name='testjt', inventory=inventory, project=project)
# Set initial job type
init_job_type = 'check' if job_type == 'run' else 'run'
# Create a job instance
job = jt.create_unified_job(_eager_fields={'job_type': init_job_type})
# Perform the POST request
url = reverse('api:job_relaunch', kwargs={'pk': job.pk})
r = post(url=url, data={'job_type': job_type}, user=admin_user, expect=201)
# Assert that the response status code is 201 (Created)
assert r.status_code == 201
# Retrieve the newly created job from the response
new_job_id = r.data.get('id')
new_job = Job.objects.get(id=new_job_id)
# Assert that the new job has the correct job type
assert new_job.job_type == job_type
class TestControllerNode: class TestControllerNode:
@pytest.fixture @pytest.fixture
def project_update(self, project): def project_update(self, project):

View File

@@ -56,175 +56,6 @@ def test_user_create(post, admin):
assert not response.data['is_system_auditor'] assert not response.data['is_system_auditor']
# Disable local password checks to ensure that any ValidationError originates from the Django validators.
@override_settings(
LOCAL_PASSWORD_MIN_LENGTH=1,
LOCAL_PASSWORD_MIN_DIGITS=0,
LOCAL_PASSWORD_MIN_UPPER=0,
LOCAL_PASSWORD_MIN_SPECIAL=0,
)
@pytest.mark.django_db
def test_user_create_with_django_password_validation_basic(post, admin):
"""Test if the Django password validators are applied correctly."""
with override_settings(
AUTH_PASSWORD_VALIDATORS=[
{
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
'OPTIONS': {
'min_length': 3,
},
},
{
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
},
],
):
# This user should fail the UserAttrSimilarity, MinLength and CommonPassword validators.
user_attrs = (
{
"password": "Password", # NOSONAR
"username": "Password",
"is_superuser": False,
},
)
print(f"Create user with invalid password {user_attrs=}")
response = post(reverse('api:user_list'), user_attrs, admin, middleware=SessionMiddleware(mock.Mock()))
assert response.status_code == 400
# This user should pass all Django validators.
user_attrs = {
"password": "r$TyKiOCb#ED", # NOSONAR
"username": "TestUser",
"is_superuser": False,
}
print(f"Create user with valid password {user_attrs=}")
response = post(reverse('api:user_list'), user_attrs, admin, middleware=SessionMiddleware(mock.Mock()))
assert response.status_code == 201
@pytest.mark.parametrize(
"user_attrs,validators,expected_status_code",
[
# Test password similarity with username.
(
{"password": "TestUser1", "username": "TestUser1", "is_superuser": False}, # NOSONAR
[
{'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator'},
],
400,
),
(
{"password": "abc", "username": "TestUser1", "is_superuser": False}, # NOSONAR
[
{'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator'},
],
201,
),
# Test password min length criterion.
(
{"password": "TooShort", "username": "TestUser1", "is_superuser": False}, # NOSONAR
[
{'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', 'OPTIONS': {'min_length': 9}},
],
400,
),
(
{"password": "LongEnough", "username": "TestUser1", "is_superuser": False}, # NOSONAR
[
{'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', 'OPTIONS': {'min_length': 9}},
],
201,
),
# Test password is too common criterion.
(
{"password": "Password", "username": "TestUser1", "is_superuser": False}, # NOSONAR
[
{'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator'},
],
400,
),
(
{"password": "aEArV$5Vkdw", "username": "TestUser1", "is_superuser": False}, # NOSONAR
[
{'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator'},
],
201,
),
# Test if password is only numeric.
(
{"password": "1234567890", "username": "TestUser1", "is_superuser": False}, # NOSONAR
[
{'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator'},
],
400,
),
(
{"password": "abc4567890", "username": "TestUser1", "is_superuser": False}, # NOSONAR
[
{'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator'},
],
201,
),
],
)
# Disable local password checks to ensure that any ValidationError originates from the Django validators.
@override_settings(
LOCAL_PASSWORD_MIN_LENGTH=1,
LOCAL_PASSWORD_MIN_DIGITS=0,
LOCAL_PASSWORD_MIN_UPPER=0,
LOCAL_PASSWORD_MIN_SPECIAL=0,
)
@pytest.mark.django_db
def test_user_create_with_django_password_validation_ext(post, delete, admin, user_attrs, validators, expected_status_code):
"""Test the functionality of the single Django password validators."""
#
default_parameters = {
# Default values for input parameters which are None.
"user_attrs": {
"password": "r$TyKiOCb#ED", # NOSONAR
"username": "DefaultUser",
"is_superuser": False,
},
"validators": [
{
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
'OPTIONS': {
'min_length': 8,
},
},
{
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
},
],
}
user_attrs = user_attrs if user_attrs is not None else default_parameters["user_attrs"]
validators = validators if validators is not None else default_parameters["validators"]
with override_settings(AUTH_PASSWORD_VALIDATORS=validators):
response = post(reverse('api:user_list'), user_attrs, admin, middleware=SessionMiddleware(mock.Mock()))
assert response.status_code == expected_status_code
# Delete user if it was created succesfully.
if response.status_code == 201:
response = delete(reverse('api:user_detail', kwargs={'pk': response.data['id']}), admin, middleware=SessionMiddleware(mock.Mock()))
assert response.status_code == 204
else:
# Catch the unexpected behavior that sometimes the user is written
# into the database before the validation fails. This actually can
# happen if UserSerializer.validate instantiates User(**attrs)!
username = user_attrs['username']
assert not User.objects.filter(username=username)
@pytest.mark.django_db @pytest.mark.django_db
def test_fail_double_create_user(post, admin): def test_fail_double_create_user(post, admin):
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin, middleware=SessionMiddleware(mock.Mock())) response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin, middleware=SessionMiddleware(mock.Mock()))
@@ -251,10 +82,6 @@ def test_updating_own_password_refreshes_session(patch, admin):
Updating your own password should refresh the session id. Updating your own password should refresh the session id.
''' '''
with mock.patch('awx.api.serializers.update_session_auth_hash') as update_session_auth_hash: with mock.patch('awx.api.serializers.update_session_auth_hash') as update_session_auth_hash:
# Attention: If the Django password validator `CommonPasswordValidator`
# is active, this test case will fail because this validator raises on
# password 'newpassword'. Consider changing the hard-coded password to
# something uncommon.
patch(reverse('api:user_detail', kwargs={'pk': admin.pk}), {'password': 'newpassword'}, admin, middleware=SessionMiddleware(mock.Mock())) patch(reverse('api:user_detail', kwargs={'pk': admin.pk}), {'password': 'newpassword'}, admin, middleware=SessionMiddleware(mock.Mock()))
assert update_session_auth_hash.called assert update_session_auth_hash.called

View File

@@ -34,18 +34,40 @@ def test_wrapup_does_send_notifications(mocker):
mock.assert_called_once_with('succeeded') mock.assert_called_once_with('succeeded')
class FakeRedis:
def keys(self, *args, **kwargs):
return []
def set(self):
pass
def get(self):
return None
@classmethod
def from_url(cls, *args, **kwargs):
return cls()
def pipeline(self):
return self
class TestCallbackBrokerWorker(TransactionTestCase): class TestCallbackBrokerWorker(TransactionTestCase):
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def turn_off_websockets_and_redis(self, fake_redis): def turn_off_websockets(self):
with mock.patch('awx.main.dispatch.worker.callback.emit_event_detail', lambda *a, **kw: None): with mock.patch('awx.main.dispatch.worker.callback.emit_event_detail', lambda *a, **kw: None):
yield yield
def get_worker(self):
with mock.patch('redis.Redis', new=FakeRedis): # turn off redis stuff
return CallbackBrokerWorker()
def event_create_kwargs(self): def event_create_kwargs(self):
inventory_update = InventoryUpdate.objects.create(source='file', inventory_source=InventorySource.objects.create(source='file')) inventory_update = InventoryUpdate.objects.create(source='file', inventory_source=InventorySource.objects.create(source='file'))
return dict(inventory_update=inventory_update, created=inventory_update.created) return dict(inventory_update=inventory_update, created=inventory_update.created)
def test_flush_with_valid_event(self): def test_flush_with_valid_event(self):
worker = CallbackBrokerWorker() worker = self.get_worker()
events = [InventoryUpdateEvent(uuid=str(uuid4()), **self.event_create_kwargs())] events = [InventoryUpdateEvent(uuid=str(uuid4()), **self.event_create_kwargs())]
worker.buff = {InventoryUpdateEvent: events} worker.buff = {InventoryUpdateEvent: events}
worker.flush() worker.flush()
@@ -53,7 +75,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
assert InventoryUpdateEvent.objects.filter(uuid=events[0].uuid).count() == 1 assert InventoryUpdateEvent.objects.filter(uuid=events[0].uuid).count() == 1
def test_flush_with_invalid_event(self): def test_flush_with_invalid_event(self):
worker = CallbackBrokerWorker() worker = self.get_worker()
kwargs = self.event_create_kwargs() kwargs = self.event_create_kwargs()
events = [ events = [
InventoryUpdateEvent(uuid=str(uuid4()), stdout='good1', **kwargs), InventoryUpdateEvent(uuid=str(uuid4()), stdout='good1', **kwargs),
@@ -68,7 +90,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
assert worker.buff == {InventoryUpdateEvent: [events[1]]} assert worker.buff == {InventoryUpdateEvent: [events[1]]}
def test_duplicate_key_not_saved_twice(self): def test_duplicate_key_not_saved_twice(self):
worker = CallbackBrokerWorker() worker = self.get_worker()
events = [InventoryUpdateEvent(uuid=str(uuid4()), **self.event_create_kwargs())] events = [InventoryUpdateEvent(uuid=str(uuid4()), **self.event_create_kwargs())]
worker.buff = {InventoryUpdateEvent: events.copy()} worker.buff = {InventoryUpdateEvent: events.copy()}
worker.flush() worker.flush()
@@ -82,7 +104,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
assert worker.buff.get(InventoryUpdateEvent, []) == [] assert worker.buff.get(InventoryUpdateEvent, []) == []
def test_give_up_on_bad_event(self): def test_give_up_on_bad_event(self):
worker = CallbackBrokerWorker() worker = self.get_worker()
events = [InventoryUpdateEvent(uuid=str(uuid4()), counter=-2, **self.event_create_kwargs())] events = [InventoryUpdateEvent(uuid=str(uuid4()), counter=-2, **self.event_create_kwargs())]
worker.buff = {InventoryUpdateEvent: events.copy()} worker.buff = {InventoryUpdateEvent: events.copy()}
@@ -95,7 +117,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
assert InventoryUpdateEvent.objects.filter(uuid=events[0].uuid).count() == 0 # sanity assert InventoryUpdateEvent.objects.filter(uuid=events[0].uuid).count() == 0 # sanity
def test_flush_with_empty_buffer(self): def test_flush_with_empty_buffer(self):
worker = CallbackBrokerWorker() worker = self.get_worker()
worker.buff = {InventoryUpdateEvent: []} worker.buff = {InventoryUpdateEvent: []}
with mock.patch.object(InventoryUpdateEvent.objects, 'bulk_create') as flush_mock: with mock.patch.object(InventoryUpdateEvent.objects, 'bulk_create') as flush_mock:
worker.flush() worker.flush()
@@ -105,7 +127,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
# In postgres, text fields reject NUL character, 0x00 # In postgres, text fields reject NUL character, 0x00
# tests use sqlite3 which will not raise an error # tests use sqlite3 which will not raise an error
# but we can still test that it is sanitized before saving # but we can still test that it is sanitized before saving
worker = CallbackBrokerWorker() worker = self.get_worker()
kwargs = self.event_create_kwargs() kwargs = self.event_create_kwargs()
events = [InventoryUpdateEvent(uuid=str(uuid4()), stdout="\x00", **kwargs)] events = [InventoryUpdateEvent(uuid=str(uuid4()), stdout="\x00", **kwargs)]
assert "\x00" in events[0].stdout # sanity assert "\x00" in events[0].stdout # sanity

View File

@@ -63,33 +63,6 @@ def swagger_autogen(requests=__SWAGGER_REQUESTS__):
return requests return requests
class FakeRedis:
def keys(self, *args, **kwargs):
return []
def set(self):
pass
def get(self):
return None
@classmethod
def from_url(cls, *args, **kwargs):
return cls()
def pipeline(self):
return self
def ping(self):
return
@pytest.fixture
def fake_redis():
with mock.patch('redis.Redis', new=FakeRedis): # turn off redis stuff
yield
@pytest.fixture @pytest.fixture
def user(): def user():
def u(name, is_superuser=False): def u(name, is_superuser=False):

View File

@@ -106,17 +106,6 @@ def test_compat_role_naming(setup_managed_roles, job_template, rando, alice):
assert rd.created_by is None assert rd.created_by is None
@pytest.mark.django_db
def test_organization_admin_has_audit(setup_managed_roles):
"""This formalizes a behavior change from old to new RBAC system
Previously, the auditor_role did not list admin_role as a parent
this made various queries hard to deal with, requiring adding 2 conditions
The new system should explicitly list the auditor permission in org admin role"""
rd = RoleDefinition.objects.get(name='Organization Admin')
assert 'audit_organization' in rd.permissions.values_list('codename', flat=True)
@pytest.mark.django_db @pytest.mark.django_db
def test_organization_level_permissions(organization, inventory, setup_managed_roles): def test_organization_level_permissions(organization, inventory, setup_managed_roles):
u1 = User.objects.create(username='alice') u1 = User.objects.create(username='alice')

View File

@@ -1,56 +0,0 @@
import pytest
from awx.main.migrations._db_constraints import _rename_duplicates
from awx.main.models import JobTemplate
@pytest.mark.django_db
def test_rename_job_template_duplicates(organization, project):
ids = []
for i in range(5):
jt = JobTemplate.objects.create(name=f'jt-{i}', organization=organization, project=project)
ids.append(jt.id) # saved in order of creation
# Hack to first allow duplicate names of JT to test migration
JobTemplate.objects.filter(id__in=ids).update(org_unique=False)
# Set all JTs to the same name
JobTemplate.objects.filter(id__in=ids).update(name='same_name_for_test')
_rename_duplicates(JobTemplate)
first_jt = JobTemplate.objects.get(id=ids[0])
assert first_jt.name == 'same_name_for_test'
for i, pk in enumerate(ids):
if i == 0:
continue
jt = JobTemplate.objects.get(id=pk)
# Name should be set based on creation order
assert jt.name == f'same_name_for_test_dup{i}'
@pytest.mark.django_db
def test_rename_job_template_name_too_long(organization, project):
ids = []
for i in range(3):
jt = JobTemplate.objects.create(name=f'jt-{i}', organization=organization, project=project)
ids.append(jt.id) # saved in order of creation
JobTemplate.objects.filter(id__in=ids).update(org_unique=False)
chars = 512
# Set all JTs to the same reaaaaaaly long name
JobTemplate.objects.filter(id__in=ids).update(name='A' * chars)
_rename_duplicates(JobTemplate)
first_jt = JobTemplate.objects.get(id=ids[0])
assert first_jt.name == 'A' * chars
for i, pk in enumerate(ids):
if i == 0:
continue
jt = JobTemplate.objects.get(id=pk)
assert jt.name.endswith(f'dup{i}')
assert len(jt.name) <= 512

View File

@@ -135,9 +135,8 @@ class TestEvents:
self._create_job_event(ok=dict((hostname, len(hostname)) for hostname in self.hostnames)) self._create_job_event(ok=dict((hostname, len(hostname)) for hostname in self.hostnames))
# Soft delete 6 of the 12 host metrics, every even host like "Host 2" or "Host 4" # Soft delete 6 host metrics
for host_name in self.hostnames[::2]: for hm in HostMetric.objects.filter(id__in=[1, 3, 5, 7, 9, 11]):
hm = HostMetric.objects.get(hostname=host_name.lower())
hm.soft_delete() hm.soft_delete()
assert len(HostMetric.objects.filter(Q(deleted=False) & Q(deleted_counter=0) & Q(last_deleted__isnull=True))) == 6 assert len(HostMetric.objects.filter(Q(deleted=False) & Q(deleted_counter=0) & Q(last_deleted__isnull=True))) == 6
@@ -166,9 +165,7 @@ class TestEvents:
skipped=dict((hostname, len(hostname)) for hostname in self.hostnames[10:12]), skipped=dict((hostname, len(hostname)) for hostname in self.hostnames[10:12]),
) )
assert len(HostMetric.objects.filter(Q(deleted=False) & Q(deleted_counter=0) & Q(last_deleted__isnull=True))) == 6 assert len(HostMetric.objects.filter(Q(deleted=False) & Q(deleted_counter=0) & Q(last_deleted__isnull=True))) == 6
assert len(HostMetric.objects.filter(Q(deleted=False) & Q(deleted_counter=1) & Q(last_deleted__isnull=False))) == 6
# one of those 6 hosts is dark, so will not be counted
assert len(HostMetric.objects.filter(Q(deleted=False) & Q(deleted_counter=1) & Q(last_deleted__isnull=False))) == 5
def _generate_hosts(self, cnt, id_from=0): def _generate_hosts(self, cnt, id_from=0):
self.hostnames = [f'Host {i}' for i in range(id_from, id_from + cnt)] self.hostnames = [f'Host {i}' for i in range(id_from, id_from + cnt)]

View File

@@ -3,10 +3,6 @@ import pytest
# AWX # AWX
from awx.main.ha import is_ha_environment from awx.main.ha import is_ha_environment
from awx.main.models.ha import Instance from awx.main.models.ha import Instance
from awx.main.dispatch.pool import get_auto_max_workers
# Django
from django.test.utils import override_settings
@pytest.mark.django_db @pytest.mark.django_db
@@ -21,25 +17,3 @@ def test_db_localhost():
Instance.objects.create(hostname='foo', node_type='hybrid') Instance.objects.create(hostname='foo', node_type='hybrid')
Instance.objects.create(hostname='bar', node_type='execution') Instance.objects.create(hostname='bar', node_type='execution')
assert is_ha_environment() is False assert is_ha_environment() is False
@pytest.mark.django_db
@pytest.mark.parametrize(
'settings',
[
dict(SYSTEM_TASK_ABS_MEM='16Gi', SYSTEM_TASK_ABS_CPU='24', SYSTEM_TASK_FORKS_MEM=400, SYSTEM_TASK_FORKS_CPU=4),
dict(SYSTEM_TASK_ABS_MEM='124Gi', SYSTEM_TASK_ABS_CPU='2', SYSTEM_TASK_FORKS_MEM=None, SYSTEM_TASK_FORKS_CPU=None),
],
ids=['cpu_dominated', 'memory_dominated'],
)
def test_dispatcher_max_workers_reserve(settings, fake_redis):
"""This tests that the dispatcher max_workers matches instance capacity
Assumes capacity_adjustment is 1,
plus reserve worker count
"""
with override_settings(**settings):
i = Instance.objects.create(hostname='test-1', node_type='hybrid', capacity_adjustment=1.0)
i.local_health_check()
assert get_auto_max_workers() == i.capacity + 7, (i.cpu, i.memory, i.cpu_capacity, i.mem_capacity)

View File

@@ -1,7 +1,6 @@
import pytest import pytest
from awx.main.access import ( from awx.main.access import (
UnifiedJobAccess,
WorkflowJobTemplateAccess, WorkflowJobTemplateAccess,
WorkflowJobTemplateNodeAccess, WorkflowJobTemplateNodeAccess,
WorkflowJobAccess, WorkflowJobAccess,
@@ -246,30 +245,6 @@ class TestWorkflowJobAccess:
inventory.use_role.members.add(rando) inventory.use_role.members.add(rando)
assert WorkflowJobAccess(rando).can_start(workflow_job) assert WorkflowJobAccess(rando).can_start(workflow_job)
@pytest.mark.parametrize('org_role', ['admin_role', 'auditor_role'])
def test_workflow_job_org_audit_access(self, workflow_job_template, rando, org_role):
assert workflow_job_template.organization # sanity
workflow_job = workflow_job_template.create_unified_job()
assert workflow_job.organization # sanity
assert not UnifiedJobAccess(rando).can_read(workflow_job)
assert not WorkflowJobAccess(rando).can_read(workflow_job)
assert workflow_job not in WorkflowJobAccess(rando).filtered_queryset()
org = workflow_job.organization
role = getattr(org, org_role)
role.members.add(rando)
assert UnifiedJobAccess(rando).can_read(workflow_job)
assert WorkflowJobAccess(rando).can_read(workflow_job)
assert workflow_job in WorkflowJobAccess(rando).filtered_queryset()
# Organization-level permissions should persist after deleting the WFJT
workflow_job_template.delete()
assert UnifiedJobAccess(rando).can_read(workflow_job)
assert WorkflowJobAccess(rando).can_read(workflow_job)
assert workflow_job in WorkflowJobAccess(rando).filtered_queryset()
@pytest.mark.django_db @pytest.mark.django_db
class TestWFJTCopyAccess: class TestWFJTCopyAccess:

View File

@@ -393,7 +393,7 @@ def test_dependency_isolation(organization):
this should keep dependencies isolated""" this should keep dependencies isolated"""
with mock.patch('awx.main.models.unified_jobs.UnifiedJobTemplate.update'): with mock.patch('awx.main.models.unified_jobs.UnifiedJobTemplate.update'):
updating_projects = [ updating_projects = [
Project.objects.create(name=f'iso-proj{i}', organization=organization, scm_url='https://foo.invalid', scm_type='git', scm_update_on_launch=True) Project.objects.create(name='iso-proj', organization=organization, scm_url='https://foo.invalid', scm_type='git', scm_update_on_launch=True)
for i in range(2) for i in range(2)
] ]

View File

@@ -1,5 +1,4 @@
import yaml import yaml
from functools import reduce
from unittest import mock from unittest import mock
import pytest import pytest
@@ -21,46 +20,6 @@ from awx.main.models.indirect_managed_node_audit import IndirectManagedNodeAudit
TEST_JQ = "{name: .name, canonical_facts: {host_name: .direct_host_name}, facts: {another_host_name: .direct_host_name}}" TEST_JQ = "{name: .name, canonical_facts: {host_name: .direct_host_name}, facts: {another_host_name: .direct_host_name}}"
class Query(dict):
def __init__(self, resolved_action: str, query_jq: dict):
self._resolved_action = resolved_action.split('.')
self._collection_ns, self._collection_name, self._module_name = self._resolved_action
super().__init__({self.resolve_key: {'query': query_jq}})
def get_fqcn(self):
return f'{self._collection_ns}.{self._collection_name}'
@property
def resolve_value(self):
return self[self.resolve_key]
@property
def resolve_key(self):
return f'{self.get_fqcn()}.{self._module_name}'
def resolve(self, module_name=None):
return {f'{self.get_fqcn()}.{module_name or self._module_name}': self.resolve_value}
def create_event_query(self, module_name=None):
if (module_name := module_name or self._module_name) == '*':
raise ValueError('Invalid module name *')
return self.create_event_queries([module_name])
def create_event_queries(self, module_names):
queries = {}
for name in module_names:
queries |= self.resolve(name)
return EventQuery.objects.create(
fqcn=self.get_fqcn(),
collection_version='1.0.1',
event_query=yaml.dump(queries, default_flow_style=False),
)
def create_registered_event(self, job, module_name):
job.job_events.create(event_data={'resolved_action': f'{self.get_fqcn()}.{module_name}', 'res': {'direct_host_name': 'foo_host', 'name': 'vm-foo'}})
@pytest.fixture @pytest.fixture
def bare_job(job_factory): def bare_job(job_factory):
job = job_factory() job = job_factory()
@@ -80,6 +39,11 @@ def job_with_counted_event(bare_job):
return bare_job return bare_job
def create_event_query(fqcn='demo.query'):
module_name = f'{fqcn}.example'
return EventQuery.objects.create(fqcn=fqcn, collection_version='1.0.1', event_query=yaml.dump({module_name: {'query': TEST_JQ}}, default_flow_style=False))
def create_audit_record(name, job, organization, created=now()): def create_audit_record(name, job, organization, created=now()):
record = IndirectManagedNodeAudit.objects.create(name=name, job=job, organization=organization) record = IndirectManagedNodeAudit.objects.create(name=name, job=job, organization=organization)
record.created = created record.created = created
@@ -90,7 +54,7 @@ def create_audit_record(name, job, organization, created=now()):
@pytest.fixture @pytest.fixture
def event_query(): def event_query():
"This is ordinarily created by the artifacts callback" "This is ordinarily created by the artifacts callback"
return Query('demo.query.example', TEST_JQ).create_event_query() return create_event_query()
@pytest.fixture @pytest.fixture
@@ -108,211 +72,105 @@ def new_audit_record(bare_job, organization):
@pytest.mark.django_db @pytest.mark.django_db
@pytest.mark.parametrize( def test_build_with_no_results(bare_job):
'queries,expected_matches', # never filled in events, should do nothing
( assert build_indirect_host_data(bare_job, {}) == []
pytest.param(
[],
0,
id='no_results',
),
pytest.param(
[Query('demo.query.example', TEST_JQ)],
1,
id='fully_qualified',
),
pytest.param(
[Query('demo.query.*', TEST_JQ)],
1,
id='wildcard',
),
pytest.param(
[
Query('demo.query.*', TEST_JQ),
Query('demo.query.example', TEST_JQ),
],
1,
id='wildcard_and_fully_qualified',
),
pytest.param(
[
Query('demo.query.*', TEST_JQ),
Query('demo.query.example', {}),
],
0,
id='wildcard_and_fully_qualified',
),
pytest.param(
[
Query('demo.query.example', {}),
Query('demo.query.*', TEST_JQ),
],
0,
id='ordering_should_not_matter',
),
),
)
def test_build_indirect_host_data(job_with_counted_event, queries: Query, expected_matches: int):
data = build_indirect_host_data(job_with_counted_event, {k: v for d in queries for k, v in d.items()})
assert len(data) == expected_matches
@mock.patch('awx.main.tasks.host_indirect.logger.debug')
@pytest.mark.django_db
@pytest.mark.parametrize(
'task_name',
(
pytest.param(
'demo.query',
id='no_results',
),
pytest.param(
'demo',
id='no_results',
),
pytest.param(
'a.b.c.d',
id='no_results',
),
),
)
def test_build_indirect_host_data_malformed_module_name(mock_logger_debug, bare_job, task_name: str):
create_registered_event(bare_job, task_name)
assert build_indirect_host_data(bare_job, Query('demo.query.example', TEST_JQ)) == []
mock_logger_debug.assert_called_once_with(f"Malformed invocation module name '{task_name}'. Expected to be of the form 'a.b.c'")
@mock.patch('awx.main.tasks.host_indirect.logger.info')
@pytest.mark.django_db
@pytest.mark.parametrize(
'query',
(
pytest.param(
'demo.query',
id='no_results',
),
pytest.param(
'demo',
id='no_results',
),
pytest.param(
'a.b.c.d',
id='no_results',
),
),
)
def test_build_indirect_host_data_malformed_query(mock_logger_info, job_with_counted_event, query: str):
assert build_indirect_host_data(job_with_counted_event, {query: {'query': TEST_JQ}}) == []
mock_logger_info.assert_called_once_with(f"Skiping malformed query '{query}'. Expected to be of the form 'a.b.c'")
@pytest.mark.django_db @pytest.mark.django_db
@pytest.mark.parametrize( def test_collect_an_event(job_with_counted_event):
'query', records = build_indirect_host_data(job_with_counted_event, {'demo.query.example': {'query': TEST_JQ}})
( assert len(records) == 1
pytest.param(
Query('demo.query.example', TEST_JQ),
id='fully_qualified',
),
pytest.param(
Query('demo.query.*', TEST_JQ),
id='wildcard',
),
),
)
def test_fetch_job_event_query(bare_job, query: Query):
query.create_event_query(module_name='example')
assert fetch_job_event_query(bare_job) == query.resolve('example')
@pytest.mark.django_db @pytest.mark.django_db
@pytest.mark.parametrize( def test_fetch_job_event_query(bare_job, event_query):
'queries', assert fetch_job_event_query(bare_job) == {'demo.query.example': {'query': TEST_JQ}}
(
[
Query('demo.query.example', TEST_JQ),
Query('demo2.query.example', TEST_JQ),
],
[
Query('demo.query.*', TEST_JQ),
Query('demo2.query.example', TEST_JQ),
],
),
)
def test_fetch_multiple_job_event_query(bare_job, queries: list[Query]):
for q in queries:
q.create_event_query(module_name='example')
assert fetch_job_event_query(bare_job) == reduce(lambda acc, q: acc | q.resolve('example'), queries, {})
@pytest.mark.django_db @pytest.mark.django_db
@pytest.mark.parametrize( def test_fetch_multiple_job_event_query(bare_job):
('state',), create_event_query(fqcn='demo.query')
( create_event_query(fqcn='demo2.query')
pytest.param( assert fetch_job_event_query(bare_job) == {'demo.query.example': {'query': TEST_JQ}, 'demo2.query.example': {'query': TEST_JQ}}
[
(
Query('demo.query.example', TEST_JQ), @pytest.mark.django_db
['example'], def test_save_indirect_host_entries(job_with_counted_event, event_query):
), assert job_with_counted_event.event_queries_processed is False
], save_indirect_host_entries(job_with_counted_event.id)
id='fully_qualified', job_with_counted_event.refresh_from_db()
), assert job_with_counted_event.event_queries_processed is True
pytest.param( assert IndirectManagedNodeAudit.objects.filter(job=job_with_counted_event).count() == 1
[ host_audit = IndirectManagedNodeAudit.objects.filter(job=job_with_counted_event).first()
( assert host_audit.count == 1
Query('demo.query.example', TEST_JQ), assert host_audit.canonical_facts == {'host_name': 'foo_host'}
['example'] * 3, assert host_audit.facts == {'another_host_name': 'foo_host'}
), assert host_audit.organization == job_with_counted_event.organization
], assert host_audit.name == 'vm-foo'
id='multiple_events_same_module_same_host',
),
pytest.param( @pytest.mark.django_db
[ def test_multiple_events_same_module_same_host(bare_job, event_query):
( "This tests that the count field gives correct answers"
Query('demo.query.example', TEST_JQ), create_registered_event(bare_job)
['example'], create_registered_event(bare_job)
), create_registered_event(bare_job)
(
Query('demo2.query.example', TEST_JQ),
['example'],
),
],
id='multiple_modules',
),
pytest.param(
[
(
Query('demo.query.*', TEST_JQ),
['example', 'example2'],
),
],
id='multiple_modules_same_collection',
),
),
)
def test_save_indirect_host_entries(bare_job, state):
all_task_names = []
for entry in state:
query, module_names = entry
all_task_names.extend([f'{query.get_fqcn()}.{module_name}' for module_name in module_names])
query.create_event_queries(module_names)
[query.create_registered_event(bare_job, n) for n in module_names]
save_indirect_host_entries(bare_job.id) save_indirect_host_entries(bare_job.id)
bare_job.refresh_from_db()
assert bare_job.event_queries_processed is True
assert IndirectManagedNodeAudit.objects.filter(job=bare_job).count() == 1 assert IndirectManagedNodeAudit.objects.filter(job=bare_job).count() == 1
host_audit = IndirectManagedNodeAudit.objects.filter(job=bare_job).first() host_audit = IndirectManagedNodeAudit.objects.filter(job=bare_job).first()
assert host_audit.count == len(all_task_names) assert host_audit.count == 3
assert host_audit.canonical_facts == {'host_name': 'foo_host'} assert host_audit.events == ['demo.query.example']
assert host_audit.facts == {'another_host_name': 'foo_host'}
assert host_audit.organization == bare_job.organization
assert host_audit.name == 'vm-foo' @pytest.mark.django_db
assert set(host_audit.events) == set(all_task_names) def test_multiple_registered_modules(bare_job):
"This tests that the events will list multiple modules if more than 1 module from different collections is registered and used"
create_registered_event(bare_job, task_name='demo.query.example')
create_registered_event(bare_job, task_name='demo2.query.example')
# These take the place of using the event_query fixture
create_event_query(fqcn='demo.query')
create_event_query(fqcn='demo2.query')
save_indirect_host_entries(bare_job.id)
assert IndirectManagedNodeAudit.objects.filter(job=bare_job).count() == 1
host_audit = IndirectManagedNodeAudit.objects.filter(job=bare_job).first()
assert host_audit.count == 2
assert set(host_audit.events) == {'demo.query.example', 'demo2.query.example'}
@pytest.mark.django_db
def test_multiple_registered_modules_same_collection(bare_job):
"This tests that the events will list multiple modules if more than 1 module in same collection is registered and used"
create_registered_event(bare_job, task_name='demo.query.example')
create_registered_event(bare_job, task_name='demo.query.example2')
# Takes place of event_query fixture, doing manually here
EventQuery.objects.create(
fqcn='demo.query',
collection_version='1.0.1',
event_query=yaml.dump(
{
'demo.query.example': {'query': TEST_JQ},
'demo.query.example2': {'query': TEST_JQ},
},
default_flow_style=False,
),
)
save_indirect_host_entries(bare_job.id)
assert IndirectManagedNodeAudit.objects.filter(job=bare_job).count() == 1
host_audit = IndirectManagedNodeAudit.objects.filter(job=bare_job).first()
assert host_audit.count == 2
assert set(host_audit.events) == {'demo.query.example', 'demo.query.example2'}
@pytest.mark.django_db @pytest.mark.django_db

View File

@@ -43,7 +43,7 @@ def test_job_template_copy(
c.save() c.save()
assert get(reverse('api:job_template_copy', kwargs={'pk': job_template_with_survey_passwords.pk}), alice, expect=200).data['can_copy'] is True assert get(reverse('api:job_template_copy', kwargs={'pk': job_template_with_survey_passwords.pk}), alice, expect=200).data['can_copy'] is True
jt_copy_pk_alice = post( jt_copy_pk_alice = post(
reverse('api:job_template_copy', kwargs={'pk': job_template_with_survey_passwords.pk}), {'name': 'new jt name alice'}, alice, expect=201 reverse('api:job_template_copy', kwargs={'pk': job_template_with_survey_passwords.pk}), {'name': 'new jt name'}, alice, expect=201
).data['id'] ).data['id']
jt_copy_admin = type(job_template_with_survey_passwords).objects.get(pk=jt_copy_pk) jt_copy_admin = type(job_template_with_survey_passwords).objects.get(pk=jt_copy_pk)
@@ -53,7 +53,7 @@ def test_job_template_copy(
assert jt_copy_alice.created_by == alice assert jt_copy_alice.created_by == alice
for jt_copy in (jt_copy_admin, jt_copy_alice): for jt_copy in (jt_copy_admin, jt_copy_alice):
assert jt_copy.name.startswith('new jt name') assert jt_copy.name == 'new jt name'
assert jt_copy.project == project assert jt_copy.project == project
assert jt_copy.inventory == inventory assert jt_copy.inventory == inventory
assert jt_copy.playbook == job_template_with_survey_passwords.playbook assert jt_copy.playbook == job_template_with_survey_passwords.playbook

View File

@@ -231,7 +231,7 @@ def test_inventory_update_injected_content(product_name, this_kind, inventory, f
len([True for k in content.keys() if k.endswith(inventory_filename)]) > 0 len([True for k in content.keys() if k.endswith(inventory_filename)]) > 0
), f"'{inventory_filename}' file not found in inventory update runtime files {content.keys()}" ), f"'{inventory_filename}' file not found in inventory update runtime files {content.keys()}"
env.pop('ANSIBLE_COLLECTIONS_PATH', None) env.pop('ANSIBLE_COLLECTIONS_PATHS', None) # collection paths not relevant to this test
base_dir = os.path.join(DATA, 'plugins') base_dir = os.path.join(DATA, 'plugins')
if not os.path.exists(base_dir): if not os.path.exists(base_dir):
os.mkdir(base_dir) os.mkdir(base_dir)

View File

@@ -19,7 +19,7 @@ from awx.main.models import (
ExecutionEnvironment, ExecutionEnvironment,
) )
from awx.main.tasks.system import cluster_node_heartbeat from awx.main.tasks.system import cluster_node_heartbeat
from awx.main.utils.db import bulk_update_sorted_by_id from awx.main.tasks.facts import update_hosts
from django.db import OperationalError from django.db import OperationalError
from django.test.utils import override_settings from django.test.utils import override_settings
@@ -39,13 +39,13 @@ def test_orphan_unified_job_creation(instance, inventory):
@pytest.mark.django_db @pytest.mark.django_db
@mock.patch('awx.main.tasks.system.inspect_execution_and_hop_nodes', lambda *args, **kwargs: None) @mock.patch('awx.main.tasks.system.inspect_execution_and_hop_nodes', lambda *args, **kwargs: None)
@mock.patch('awx.main.models.ha.get_cpu_effective_capacity', lambda cpu, is_control_node: 8) @mock.patch('awx.main.models.ha.get_cpu_effective_capacity', lambda cpu, is_control_node: 8)
@mock.patch('awx.main.models.ha.get_mem_effective_capacity', lambda mem, is_control_node: 64) @mock.patch('awx.main.models.ha.get_mem_effective_capacity', lambda mem, is_control_node: 62)
def test_job_capacity_and_with_inactive_node(): def test_job_capacity_and_with_inactive_node():
i = Instance.objects.create(hostname='test-1') i = Instance.objects.create(hostname='test-1')
i.save_health_data('18.0.1', 2, 8000) i.save_health_data('18.0.1', 2, 8000)
assert i.enabled is True assert i.enabled is True
assert i.capacity_adjustment == 0.75 assert i.capacity_adjustment == 1.0
assert i.capacity == 50 assert i.capacity == 62
i.enabled = False i.enabled = False
i.save() i.save()
with override_settings(CLUSTER_HOST_ID=i.hostname): with override_settings(CLUSTER_HOST_ID=i.hostname):
@@ -128,7 +128,7 @@ class TestAnsibleFactsSave:
assert inventory.hosts.count() == 3 assert inventory.hosts.count() == 3
Host.objects.get(pk=last_pk).delete() Host.objects.get(pk=last_pk).delete()
assert inventory.hosts.count() == 2 assert inventory.hosts.count() == 2
bulk_update_sorted_by_id(Host, hosts, fields=['ansible_facts']) update_hosts(hosts)
assert inventory.hosts.count() == 2 assert inventory.hosts.count() == 2
for host in inventory.hosts.all(): for host in inventory.hosts.all():
host.refresh_from_db() host.refresh_from_db()
@@ -141,7 +141,7 @@ class TestAnsibleFactsSave:
db_mock = mocker.patch('awx.main.tasks.facts.Host.objects.bulk_update') db_mock = mocker.patch('awx.main.tasks.facts.Host.objects.bulk_update')
db_mock.side_effect = OperationalError('deadlock detected') db_mock.side_effect = OperationalError('deadlock detected')
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
bulk_update_sorted_by_id(Host, hosts, fields=['ansible_facts']) update_hosts(hosts)
def fake_bulk_update(self, host_list): def fake_bulk_update(self, host_list):
if self.current_call > 2: if self.current_call > 2:
@@ -149,28 +149,16 @@ class TestAnsibleFactsSave:
self.current_call += 1 self.current_call += 1
raise OperationalError('deadlock detected') raise OperationalError('deadlock detected')
def test_update_hosts_resolved_deadlock(self, inventory, mocker):
@pytest.mark.django_db hosts = [Host.objects.create(inventory=inventory, name=f'foo{i}') for i in range(3)]
def test_update_hosts_resolved_deadlock(inventory, mocker): for host in hosts:
host.ansible_facts = {'foo': 'bar'}
hosts = [Host.objects.create(inventory=inventory, name=f'foo{i}') for i in range(3)] self.current_call = 0
mocker.patch('awx.main.tasks.facts.raw_update_hosts', new=self.fake_bulk_update)
# Set ansible_facts for each host update_hosts(hosts)
for host in hosts: for host in inventory.hosts.all():
host.ansible_facts = {'foo': 'bar'} host.refresh_from_db()
assert host.ansible_facts == {'foo': 'bar'}
bulk_update_sorted_by_id(Host, hosts, fields=['ansible_facts'])
# Save changes and refresh from DB to ensure the updated facts are saved
for host in hosts:
host.save() # Ensure changes are persisted in the DB
host.refresh_from_db() # Refresh from DB to get latest data
# Assert that the ansible_facts were updated correctly
for host in inventory.hosts.all():
assert host.ansible_facts == {'foo': 'bar'}
bulk_update_sorted_by_id(Host, hosts, fields=['ansible_facts'])
@pytest.mark.django_db @pytest.mark.django_db

View File

@@ -106,37 +106,3 @@ class TestMigrationSmoke:
) )
DABPermission = new_state.apps.get_model('dab_rbac', 'DABPermission') DABPermission = new_state.apps.get_model('dab_rbac', 'DABPermission')
assert not DABPermission.objects.filter(codename='view_executionenvironment').exists() assert not DABPermission.objects.filter(codename='view_executionenvironment').exists()
# Test create a Project with a duplicate name
Organization = new_state.apps.get_model('main', 'Organization')
Project = new_state.apps.get_model('main', 'Project')
org = Organization.objects.create(name='duplicate-obj-organization', created=now(), modified=now())
proj_ids = []
for i in range(3):
proj = Project.objects.create(name='duplicate-project-name', organization=org, created=now(), modified=now())
proj_ids.append(proj.id)
# The uniqueness rules will not apply to InventorySource
Inventory = new_state.apps.get_model('main', 'Inventory')
InventorySource = new_state.apps.get_model('main', 'InventorySource')
inv = Inventory.objects.create(name='migration-test-inv', organization=org, created=now(), modified=now())
InventorySource.objects.create(name='migration-test-src', source='file', inventory=inv, organization=org, created=now(), modified=now())
new_state = migrator.apply_tested_migration(
('main', '0200_template_name_constraint'),
)
for i, proj_id in enumerate(proj_ids):
proj = Project.objects.get(id=proj_id)
if i == 0:
assert proj.name == 'duplicate-project-name'
else:
assert proj.name != 'duplicate-project-name'
assert proj.name.startswith('duplicate-project-name')
# The inventory source had this field set to avoid the constrains
InventorySource = new_state.apps.get_model('main', 'InventorySource')
inv_src = InventorySource.objects.get(name='migration-test-src')
assert inv_src.org_unique is False
Project = new_state.apps.get_model('main', 'Project')
for proj in Project.objects.all():
assert proj.org_unique is True

View File

@@ -1,633 +0,0 @@
import json
import os
from unittest import mock
import pytest
import requests.exceptions
from django.test import override_settings
from awx.main.models import (
Job,
Inventory,
Project,
Organization,
JobTemplate,
Credential,
CredentialType,
User,
Team,
Label,
WorkflowJob,
WorkflowJobNode,
InventorySource,
)
from awx.main.exceptions import PolicyEvaluationError
from awx.main.tasks import policy
from awx.main.tasks.policy import JobSerializer, OPA_AUTH_TYPES
def _parse_exception_message(exception: PolicyEvaluationError):
pe_plain = str(exception.value)
assert "This job cannot be executed due to a policy violation or error. See the following details:" in pe_plain
violation_message = "This job cannot be executed due to a policy violation or error. See the following details:"
return eval(pe_plain.split(violation_message)[1].strip())
@pytest.fixture(autouse=True)
def enable_flag():
with override_settings(
OPA_HOST='opa.example.com',
FLAGS={"FEATURE_POLICY_AS_CODE_ENABLED": [("boolean", True)]},
FLAG_SOURCES=('flags.sources.SettingsFlagsSource',),
):
yield
@pytest.fixture
def opa_client():
cls_mock = mock.MagicMock(name='OpaClient')
instance_mock = cls_mock.return_value
instance_mock.__enter__.return_value = instance_mock
with mock.patch('awx.main.tasks.policy.OpaClient', cls_mock):
yield instance_mock
@pytest.fixture
def job():
project: Project = Project.objects.create(name='proj1', scm_type='git', scm_branch='main', scm_url='https://git.example.com/proj1')
inventory: Inventory = Inventory.objects.create(name='inv1', opa_query_path="inventory/response")
org: Organization = Organization.objects.create(name="org1", opa_query_path="organization/response")
jt: JobTemplate = JobTemplate.objects.create(name="jt1", opa_query_path="job_template/response")
job: Job = Job.objects.create(name='job1', extra_vars="{}", inventory=inventory, project=project, organization=org, job_template=jt)
return job
@pytest.mark.django_db
def test_job_serializer():
user: User = User.objects.create(username='user1')
org: Organization = Organization.objects.create(name='org1')
team: Team = Team.objects.create(name='team1', organization=org)
team.admin_role.members.add(user)
project: Project = Project.objects.create(name='proj1', scm_type='git', scm_branch='main', scm_url='https://git.example.com/proj1')
inventory: Inventory = Inventory.objects.create(name='inv1', description='Demo inventory')
inventory_source: InventorySource = InventorySource.objects.create(name='inv-src1', source='file', inventory=inventory)
extra_vars = {"FOO": "value1", "BAR": "value2"}
CredentialType.setup_tower_managed_defaults()
cred_type_ssh: CredentialType = CredentialType.objects.get(kind='ssh')
cred: Credential = Credential.objects.create(name="cred1", description='Demo credential', credential_type=cred_type_ssh, organization=org)
label: Label = Label.objects.create(name='label1', organization=org)
job: Job = Job.objects.create(
name='job1', extra_vars=json.dumps(extra_vars), inventory=inventory, project=project, organization=org, created_by=user, launch_type='workflow'
)
# job.unified_job_node.workflow_job = workflow_job
job.credentials.add(cred)
job.labels.add(label)
workflow_job: WorkflowJob = WorkflowJob.objects.create(name='wf-job1')
WorkflowJobNode.objects.create(job=job, workflow_job=workflow_job)
serializer = JobSerializer(instance=job)
assert serializer.data == {
'id': job.id,
'name': 'job1',
'created': job.created.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
'created_by': {
'id': user.id,
'username': 'user1',
'is_superuser': False,
'teams': [
{'id': team.id, 'name': 'team1'},
],
},
'credentials': [
{
'id': cred.id,
'name': 'cred1',
'description': 'Demo credential',
'organization': {
'id': org.id,
'name': 'org1',
},
'credential_type': cred_type_ssh.id,
'kind': 'ssh',
'managed': False,
'kubernetes': False,
'cloud': False,
},
],
'execution_environment': None,
'extra_vars': extra_vars,
'forks': 0,
'hosts_count': 0,
'instance_group': None,
'inventory': {
'id': inventory.id,
'name': 'inv1',
'description': 'Demo inventory',
'kind': '',
'total_hosts': 0,
'total_groups': 0,
'has_inventory_sources': False,
'total_inventory_sources': 0,
'has_active_failures': False,
'hosts_with_active_failures': 0,
'inventory_sources': [
{
'id': inventory_source.id,
'name': 'inv-src1',
'source': 'file',
'status': 'never updated',
}
],
},
'job_template': None,
'job_type': 'run',
'job_type_name': 'job',
'labels': [
{
'id': label.id,
'name': 'label1',
'organization': {
'id': org.id,
'name': 'org1',
},
},
],
'launch_type': 'workflow',
'limit': '',
'launched_by': {},
'organization': {
'id': org.id,
'name': 'org1',
},
'playbook': '',
'project': {
'id': project.id,
'name': 'proj1',
'status': 'pending',
'scm_type': 'git',
'scm_url': 'https://git.example.com/proj1',
'scm_branch': 'main',
'scm_refspec': '',
'scm_clean': False,
'scm_track_submodules': False,
'scm_delete_on_update': False,
},
'scm_branch': '',
'scm_revision': '',
'workflow_job': {
'id': workflow_job.id,
'name': 'wf-job1',
},
'workflow_job_template': None,
}
@pytest.mark.django_db
def test_evaluate_policy_missing_opa_query_path_field(opa_client):
project: Project = Project.objects.create(name='proj1', scm_type='git', scm_branch='main', scm_url='https://git.example.com/proj1')
inventory: Inventory = Inventory.objects.create(name='inv1')
org: Organization = Organization.objects.create(name="org1")
jt: JobTemplate = JobTemplate.objects.create(name="jt1")
job: Job = Job.objects.create(name='job1', extra_vars="{}", inventory=inventory, project=project, organization=org, job_template=jt)
response = {
"result": {
"allowed": True,
"violations": [],
}
}
opa_client.query_rule.return_value = response
try:
policy.evaluate_policy(job)
except PolicyEvaluationError as e:
pytest.fail(f"Must not raise PolicyEvaluationError: {e}")
assert opa_client.query_rule.call_count == 0
@pytest.mark.django_db
def test_evaluate_policy(opa_client, job):
response = {
"result": {
"allowed": True,
"violations": [],
}
}
opa_client.query_rule.return_value = response
try:
policy.evaluate_policy(job)
except PolicyEvaluationError as e:
pytest.fail(f"Must not raise PolicyEvaluationError: {e}")
opa_client.query_rule.assert_has_calls(
[
mock.call(input_data=mock.ANY, package_path='organization/response'),
mock.call(input_data=mock.ANY, package_path='inventory/response'),
mock.call(input_data=mock.ANY, package_path='job_template/response'),
],
any_order=False,
)
assert opa_client.query_rule.call_count == 3
@pytest.mark.django_db
def test_evaluate_policy_allowed(opa_client, job):
response = {
"result": {
"allowed": True,
"violations": [],
}
}
opa_client.query_rule.return_value = response
try:
policy.evaluate_policy(job)
except PolicyEvaluationError as e:
pytest.fail(f"Must not raise PolicyEvaluationError: {e}")
assert opa_client.query_rule.call_count == 3
@pytest.mark.django_db
def test_evaluate_policy_not_allowed(opa_client, job):
response = {
"result": {
"allowed": False,
"violations": ["Access not allowed."],
}
}
opa_client.query_rule.return_value = response
with pytest.raises(PolicyEvaluationError) as pe:
policy.evaluate_policy(job)
pe_plain = str(pe.value)
assert "Errors:" not in pe_plain
exception = _parse_exception_message(pe)
assert exception["Violations"]["Organization"] == ["Access not allowed."]
assert exception["Violations"]["Inventory"] == ["Access not allowed."]
assert exception["Violations"]["Job template"] == ["Access not allowed."]
assert opa_client.query_rule.call_count == 3
@pytest.mark.django_db
def test_evaluate_policy_not_found(opa_client, job):
response = {}
opa_client.query_rule.return_value = response
with pytest.raises(PolicyEvaluationError) as pe:
policy.evaluate_policy(job)
missing_result_property = 'Call to OPA did not return a "result" property. The path refers to an undefined document.'
exception = _parse_exception_message(pe)
assert exception["Errors"]["Organization"] == missing_result_property
assert exception["Errors"]["Inventory"] == missing_result_property
assert exception["Errors"]["Job template"] == missing_result_property
assert opa_client.query_rule.call_count == 3
@pytest.mark.django_db
def test_evaluate_policy_server_error(opa_client, job):
http_error_msg = '500 Server Error: Internal Server Error for url: https://opa.example.com:8181/v1/data/job_template/response/invalid'
error_response = {
'code': 'internal_error',
'message': (
'1 error occurred: 1:1: rego_type_error: undefined ref: data.job_template.response.invalid\n\t'
'data.job_template.response.invalid\n\t'
' ^\n\t'
' have: "invalid"\n\t'
' want (one of): ["allowed" "violations"]'
),
}
response = mock.Mock()
response.status_code = requests.codes.internal_server_error
response.json.return_value = error_response
opa_client.query_rule.side_effect = requests.exceptions.HTTPError(http_error_msg, response=response)
with pytest.raises(PolicyEvaluationError) as pe:
policy.evaluate_policy(job)
exception = _parse_exception_message(pe)
assert exception["Errors"]["Organization"] == f'Call to OPA failed. Code: internal_error, Message: {error_response["message"]}'
assert exception["Errors"]["Inventory"] == f'Call to OPA failed. Code: internal_error, Message: {error_response["message"]}'
assert exception["Errors"]["Job template"] == f'Call to OPA failed. Code: internal_error, Message: {error_response["message"]}'
assert opa_client.query_rule.call_count == 3
@pytest.mark.django_db
def test_evaluate_policy_invalid_result(opa_client, job):
response = {
"result": {
"absolutely": "no!",
}
}
opa_client.query_rule.return_value = response
with pytest.raises(PolicyEvaluationError) as pe:
policy.evaluate_policy(job)
invalid_result = 'OPA policy returned invalid result.'
exception = _parse_exception_message(pe)
assert exception["Errors"]["Organization"] == invalid_result
assert exception["Errors"]["Inventory"] == invalid_result
assert exception["Errors"]["Job template"] == invalid_result
assert opa_client.query_rule.call_count == 3
@pytest.mark.django_db
def test_evaluate_policy_failed_exception(opa_client, job):
error_response = {}
response = mock.Mock()
response.status_code = requests.codes.internal_server_error
response.json.return_value = error_response
opa_client.query_rule.side_effect = ValueError("Invalid JSON")
with pytest.raises(PolicyEvaluationError) as pe:
policy.evaluate_policy(job)
opa_failed_exception = 'Call to OPA failed. Exception: Invalid JSON'
exception = _parse_exception_message(pe)
assert exception["Errors"]["Organization"] == opa_failed_exception
assert exception["Errors"]["Inventory"] == opa_failed_exception
assert exception["Errors"]["Job template"] == opa_failed_exception
assert opa_client.query_rule.call_count == 3
@pytest.mark.django_db
@pytest.mark.parametrize(
"settings_kwargs, expected_client_cert, expected_verify, verify_content",
[
# Case 1: Certificate-based authentication (mTLS)
(
{
"OPA_HOST": "opa.example.com",
"OPA_SSL": True,
"OPA_AUTH_TYPE": OPA_AUTH_TYPES.CERTIFICATE,
"OPA_AUTH_CLIENT_CERT": "-----BEGIN CERTIFICATE-----\nMIICert\n-----END CERTIFICATE-----",
"OPA_AUTH_CLIENT_KEY": "-----BEGIN PRIVATE KEY-----\nMIIKey\n-----END PRIVATE KEY-----",
"OPA_AUTH_CA_CERT": "-----BEGIN CERTIFICATE-----\nMIICACert\n-----END CERTIFICATE-----",
},
True, # Client cert should be created
"file", # Verify path should be a file
"-----BEGIN CERTIFICATE-----", # Expected content in verify file
),
# Case 2: SSL with server verification only
(
{
"OPA_HOST": "opa.example.com",
"OPA_SSL": True,
"OPA_AUTH_TYPE": OPA_AUTH_TYPES.NONE,
"OPA_AUTH_CA_CERT": "-----BEGIN CERTIFICATE-----\nMIICACert\n-----END CERTIFICATE-----",
},
False, # No client cert should be created
"file", # Verify path should be a file
"-----BEGIN CERTIFICATE-----", # Expected content in verify file
),
# Case 3: SSL with system CA store
(
{
"OPA_HOST": "opa.example.com",
"OPA_SSL": True,
"OPA_AUTH_TYPE": OPA_AUTH_TYPES.NONE,
"OPA_AUTH_CA_CERT": "", # No custom CA cert
},
False, # No client cert should be created
True, # Verify path should be True (system CA store)
None, # No file to check content
),
# Case 4: No SSL
(
{
"OPA_HOST": "opa.example.com",
"OPA_SSL": False,
"OPA_AUTH_TYPE": OPA_AUTH_TYPES.NONE,
},
False, # No client cert should be created
False, # Verify path should be False (no verification)
None, # No file to check content
),
],
ids=[
"certificate_auth",
"ssl_server_verification",
"ssl_system_ca_store",
"no_ssl",
],
)
def test_opa_cert_file(settings_kwargs, expected_client_cert, expected_verify, verify_content):
"""Parameterized test for the opa_cert_file context manager.
Tests different configurations:
- Certificate-based authentication (mTLS)
- SSL with server verification only
- SSL with system CA store
- No SSL
"""
with override_settings(**settings_kwargs):
client_cert_path = None
verify_path = None
with policy.opa_cert_file() as cert_files:
client_cert_path, verify_path = cert_files
# Check client cert based on expected_client_cert
if expected_client_cert:
assert client_cert_path is not None
with open(client_cert_path, 'r') as f:
content = f.read()
assert "-----BEGIN CERTIFICATE-----" in content
assert "-----BEGIN PRIVATE KEY-----" in content
else:
assert client_cert_path is None
# Check verify path based on expected_verify
if expected_verify == "file":
assert verify_path is not None
assert os.path.isfile(verify_path)
with open(verify_path, 'r') as f:
content = f.read()
assert verify_content in content
else:
assert verify_path is expected_verify
# Verify files are deleted after context manager exits
if expected_client_cert:
assert not os.path.exists(client_cert_path), "Client cert file was not deleted"
if expected_verify == "file":
assert not os.path.exists(verify_path), "CA cert file was not deleted"
@pytest.mark.django_db
@override_settings(
OPA_HOST='opa.example.com',
OPA_SSL=False, # SSL disabled
OPA_AUTH_TYPE=OPA_AUTH_TYPES.CERTIFICATE, # But cert auth enabled
OPA_AUTH_CLIENT_CERT="-----BEGIN CERTIFICATE-----\nMIICert\n-----END CERTIFICATE-----",
OPA_AUTH_CLIENT_KEY="-----BEGIN PRIVATE KEY-----\nMIIKey\n-----END PRIVATE KEY-----",
)
def test_evaluate_policy_cert_auth_requires_ssl():
"""Test that policy evaluation raises an error when certificate auth is used without SSL."""
project = Project.objects.create(name='proj1')
inventory = Inventory.objects.create(name='inv1', opa_query_path="inventory/response")
org = Organization.objects.create(name="org1", opa_query_path="organization/response")
jt = JobTemplate.objects.create(name="jt1", opa_query_path="job_template/response")
job = Job.objects.create(name='job1', extra_vars="{}", inventory=inventory, project=project, organization=org, job_template=jt)
with pytest.raises(PolicyEvaluationError) as pe:
policy.evaluate_policy(job)
assert "OPA_AUTH_TYPE=Certificate requires OPA_SSL to be enabled" in str(pe.value)
@pytest.mark.django_db
@override_settings(
OPA_HOST='opa.example.com',
OPA_SSL=True,
OPA_AUTH_TYPE=OPA_AUTH_TYPES.CERTIFICATE,
OPA_AUTH_CLIENT_CERT="", # Missing client cert
OPA_AUTH_CLIENT_KEY="", # Missing client key
OPA_AUTH_CA_CERT="", # Missing CA cert
)
def test_evaluate_policy_missing_cert_settings():
"""Test that policy evaluation raises an error when certificate settings are missing."""
project = Project.objects.create(name='proj1')
inventory = Inventory.objects.create(name='inv1', opa_query_path="inventory/response")
org = Organization.objects.create(name="org1", opa_query_path="organization/response")
jt = JobTemplate.objects.create(name="jt1", opa_query_path="job_template/response")
job = Job.objects.create(name='job1', extra_vars="{}", inventory=inventory, project=project, organization=org, job_template=jt)
with pytest.raises(PolicyEvaluationError) as pe:
policy.evaluate_policy(job)
error_msg = str(pe.value)
assert "Following certificate settings are missing for OPA_AUTH_TYPE=Certificate:" in error_msg
assert "OPA_AUTH_CLIENT_CERT" in error_msg
assert "OPA_AUTH_CLIENT_KEY" in error_msg
assert "OPA_AUTH_CA_CERT" in error_msg
@pytest.mark.django_db
@override_settings(
OPA_HOST='opa.example.com',
OPA_PORT=8181,
OPA_SSL=True,
OPA_AUTH_TYPE=OPA_AUTH_TYPES.CERTIFICATE,
OPA_AUTH_CLIENT_CERT="-----BEGIN CERTIFICATE-----\nMIICert\n-----END CERTIFICATE-----",
OPA_AUTH_CLIENT_KEY="-----BEGIN PRIVATE KEY-----\nMIIKey\n-----END PRIVATE KEY-----",
OPA_AUTH_CA_CERT="-----BEGIN CERTIFICATE-----\nMIICACert\n-----END CERTIFICATE-----",
OPA_REQUEST_TIMEOUT=2.5,
OPA_REQUEST_RETRIES=3,
)
def test_opa_client_context_manager_mtls():
"""Test that opa_client context manager correctly initializes the OPA client."""
# Mock the OpaClient class
with mock.patch('awx.main.tasks.policy.OpaClient') as mock_opa_client:
# Setup the mock
mock_instance = mock_opa_client.return_value
mock_instance.__enter__.return_value = mock_instance
mock_instance._session = mock.MagicMock()
# Use the context manager
with policy.opa_client(headers={'Custom-Header': 'Value'}) as client:
# Verify the client was initialized with the correct parameters
mock_opa_client.assert_called_once_with(
host='opa.example.com',
port=8181,
headers={'Custom-Header': 'Value'},
ssl=True,
cert=mock.ANY, # We can't check the exact value as it's a temporary file
timeout=2.5,
retries=3,
)
# Verify the session properties were set correctly
assert client._session.cert is not None
assert client._session.verify is not None
# Check the content of the cert file
cert_file_path = client._session.cert
assert os.path.isfile(cert_file_path)
with open(cert_file_path, 'r') as f:
cert_content = f.read()
assert "-----BEGIN CERTIFICATE-----" in cert_content
assert "MIICert" in cert_content
assert "-----BEGIN PRIVATE KEY-----" in cert_content
assert "MIIKey" in cert_content
# Check the content of the verify file
verify_file_path = client._session.verify
assert os.path.isfile(verify_file_path)
with open(verify_file_path, 'r') as f:
verify_content = f.read()
assert "-----BEGIN CERTIFICATE-----" in verify_content
assert "MIICACert" in verify_content
# Verify the client is the mocked instance
assert client is mock_instance
# Store file paths for checking after context exit
cert_path = client._session.cert
verify_path = client._session.verify
# Verify files are deleted after context manager exits
assert not os.path.exists(cert_path), "Client cert file was not deleted"
assert not os.path.exists(verify_path), "CA cert file was not deleted"
@pytest.mark.django_db
@override_settings(
OPA_HOST='opa.example.com',
OPA_SSL=True,
OPA_AUTH_TYPE=OPA_AUTH_TYPES.TOKEN,
OPA_AUTH_TOKEN='secret-token',
OPA_AUTH_CUSTOM_HEADERS={'X-Custom': 'Header'},
)
def test_opa_client_token_auth():
"""Test that token authentication correctly adds the Authorization header."""
# Create a job for testing
project = Project.objects.create(name='proj1')
inventory = Inventory.objects.create(name='inv1', opa_query_path="inventory/response")
org = Organization.objects.create(name="org1", opa_query_path="organization/response")
jt = JobTemplate.objects.create(name="jt1", opa_query_path="job_template/response")
job = Job.objects.create(name='job1', extra_vars="{}", inventory=inventory, project=project, organization=org, job_template=jt)
# Mock the OpaClient class
with mock.patch('awx.main.tasks.policy.opa_client') as mock_opa_client_cm:
# Setup the mock
mock_client = mock.MagicMock()
mock_opa_client_cm.return_value.__enter__.return_value = mock_client
mock_client.query_rule.return_value = {
"result": {
"allowed": True,
"violations": [],
}
}
# Call evaluate_policy
policy.evaluate_policy(job)
# Verify opa_client was called with the correct headers
expected_headers = {'X-Custom': 'Header', 'Authorization': 'Bearer secret-token'}
mock_opa_client_cm.assert_called_once_with(headers=expected_headers)

View File

@@ -436,22 +436,21 @@ def test_project_list_ordering_by_name(get, order_by, expected_names, organizati
@pytest.mark.parametrize('order_by', ('name', '-name')) @pytest.mark.parametrize('order_by', ('name', '-name'))
@pytest.mark.django_db @pytest.mark.django_db
def test_project_list_ordering_with_duplicate_names(get, order_by, admin): def test_project_list_ordering_with_duplicate_names(get, order_by, organization_factory):
# why? because all the '1' mean that all the names are the same, you can't sort based on that, # why? because all the '1' mean that all the names are the same, you can't sort based on that,
# meaning you have to fall back on the default sort order, which in this case, is ID # meaning you have to fall back on the default sort order, which in this case, is ID
'ensure sorted order of project list is maintained correctly when the project names the same' 'ensure sorted order of project list is maintained correctly when the project names the same'
from awx.main.models import Organization objects = organization_factory(
'org1',
projects = [] projects=['1', '1', '1', '1', '1'],
for i in range(5): superusers=['admin'],
projects.append(Project.objects.create(name='1', organization=Organization.objects.create(name=f'org{i}'))) )
project_ids = {} project_ids = {}
for x in range(3): for x in range(3):
results = get(reverse('api:project_list'), user=admin, QUERY_STRING='order_by=%s' % order_by).data['results'] results = get(reverse('api:project_list'), objects.superusers.admin, QUERY_STRING='order_by=%s' % order_by).data['results']
project_ids[x] = [proj['id'] for proj in results] project_ids[x] = [proj['id'] for proj in results]
assert project_ids[0] == project_ids[1] == project_ids[2] assert project_ids[0] == project_ids[1] == project_ids[2]
assert project_ids[0] == sorted(project_ids[0]) assert project_ids[0] == sorted(project_ids[0])
assert set(project_ids[0]) == set([proj.id for proj in projects])
@pytest.mark.django_db @pytest.mark.django_db

View File

@@ -36,7 +36,7 @@ def test_bootstrap_consistent():
assert not different_requirements assert not different_requirements
@pytest.mark.xfail(reason="This test needs some love") @pytest.mark.skip(reason="This test needs some love")
def test_env_matches_requirements_txt(): def test_env_matches_requirements_txt():
from pip.operations import freeze from pip.operations import freeze

View File

@@ -74,7 +74,7 @@ class TestWebsocketEventConsumer:
connected, _ = await server.connect() connected, _ = await server.connect()
assert connected is False, "Anonymous user should NOT be allowed to login." assert connected is False, "Anonymous user should NOT be allowed to login."
@pytest.mark.xfail(reason="Ran out of coding time.") @pytest.mark.skip(reason="Ran out of coding time.")
async def test_authorized(self, websocket_server_generator, application, admin): async def test_authorized(self, websocket_server_generator, application, admin):
server = websocket_server_generator('/websocket/') server = websocket_server_generator('/websocket/')

View File

@@ -1,77 +0,0 @@
import multiprocessing
import json
import pytest
import requests
from requests.auth import HTTPBasicAuth
from django.db import connection
from awx.main.models import User, JobTemplate
def create_in_subprocess(project_id, ready_event, continue_event, admin_auth):
connection.connect()
print('setting ready event')
ready_event.set()
print('waiting for continue event')
continue_event.wait()
if JobTemplate.objects.filter(name='test_jt_duplicate_name').exists():
for jt in JobTemplate.objects.filter(name='test_jt_duplicate_name'):
jt.delete()
assert JobTemplate.objects.filter(name='test_jt_duplicate_name').count() == 0
jt_data = {'name': 'test_jt_duplicate_name', 'project': project_id, 'playbook': 'hello_world.yml', 'ask_inventory_on_launch': True}
response = requests.post('http://localhost:8013/api/v2/job_templates/', json=jt_data, auth=admin_auth)
# should either have a conflict or create
assert response.status_code in (400, 201)
print(f'Subprocess got {response.status_code}')
if response.status_code == 400:
print(json.dumps(response.json(), indent=2))
return response.status_code
@pytest.fixture
def admin_for_test():
user, created = User.objects.get_or_create(username='admin_for_test', defaults={'is_superuser': True})
if created:
user.set_password('for_test_123!')
user.save()
print(f'Created user {user.username}')
return user
@pytest.fixture
def admin_auth(admin_for_test):
return HTTPBasicAuth(admin_for_test.username, 'for_test_123!')
def test_jt_duplicate_name(admin_auth, demo_proj):
N_processes = 5
ready_events = [multiprocessing.Event() for _ in range(N_processes)]
continue_event = multiprocessing.Event()
processes = []
for i in range(N_processes):
p = multiprocessing.Process(target=create_in_subprocess, args=(demo_proj.id, ready_events[i], continue_event, admin_auth))
processes.append(p)
p.start()
# Assure both processes are connected and have loaded their host list
for e in ready_events:
print('waiting on subprocess ready event')
e.wait()
# Begin the bulk_update queries
print('setting the continue event for the workers')
continue_event.set()
# if a Deadloack happens it will probably be surfaced by result here
print('waiting on the workers to finish the creation')
for p in processes:
p.join()
assert JobTemplate.objects.filter(name='test_jt_duplicate_name').count() == 1

View File

@@ -3,7 +3,6 @@ import time
import os import os
import shutil import shutil
import tempfile import tempfile
import logging
import pytest import pytest
@@ -20,9 +19,6 @@ from awx.main.tests import data
from awx.main.models import Project, JobTemplate, Organization, Inventory from awx.main.models import Project, JobTemplate, Organization, Inventory
logger = logging.getLogger(__name__)
PROJ_DATA = os.path.join(os.path.dirname(data.__file__), 'projects') PROJ_DATA = os.path.join(os.path.dirname(data.__file__), 'projects')
@@ -114,12 +110,6 @@ def demo_inv(default_org):
return inventory return inventory
@pytest.fixture(scope='session')
def demo_proj(default_org):
proj, _ = Project.objects.get_or_create(name='Demo Project', defaults={'organization': default_org})
return proj
@pytest.fixture @pytest.fixture
def podman_image_generator(): def podman_image_generator():
""" """
@@ -138,29 +128,30 @@ def podman_image_generator():
@pytest.fixture @pytest.fixture
def project_factory(post, default_org, admin): def run_job_from_playbook(default_org, demo_inv, post, admin):
def _rf(scm_url=None, local_path=None): def _rf(test_name, playbook, local_path=None, scm_url=None):
proj_kwargs = {} project_name = f'{test_name} project'
jt_name = f'{test_name} JT: {playbook}'
old_proj = Project.objects.filter(name=project_name).first()
if old_proj:
old_proj.delete()
old_jt = JobTemplate.objects.filter(name=jt_name).first()
if old_jt:
old_jt.delete()
proj_kwargs = {'name': project_name, 'organization': default_org.id}
if local_path: if local_path:
# manual path # manual path
project_name = f'Manual roject {local_path}'
proj_kwargs['scm_type'] = '' proj_kwargs['scm_type'] = ''
proj_kwargs['local_path'] = local_path proj_kwargs['local_path'] = local_path
elif scm_url: elif scm_url:
project_name = f'Project {scm_url}'
proj_kwargs['scm_type'] = 'git' proj_kwargs['scm_type'] = 'git'
proj_kwargs['scm_url'] = scm_url proj_kwargs['scm_url'] = scm_url
else: else:
raise RuntimeError('Need to provide scm_url or local_path') raise RuntimeError('Need to provide scm_url or local_path')
proj_kwargs['name'] = project_name
proj_kwargs['organization'] = default_org.id
old_proj = Project.objects.filter(name=project_name).first()
if old_proj:
logger.info(f'Deleting existing project {project_name}')
old_proj.delete()
result = post( result = post(
reverse('api:project_list'), reverse('api:project_list'),
proj_kwargs, proj_kwargs,
@@ -168,23 +159,6 @@ def project_factory(post, default_org, admin):
expect=201, expect=201,
) )
proj = Project.objects.get(id=result.data['id']) proj = Project.objects.get(id=result.data['id'])
return proj
return _rf
@pytest.fixture
def run_job_from_playbook(demo_inv, post, admin, project_factory):
def _rf(test_name, playbook, local_path=None, scm_url=None, jt_params=None, proj=None):
jt_name = f'{test_name} JT: {playbook}'
if not proj:
proj = project_factory(scm_url=scm_url, local_path=local_path)
old_jt = JobTemplate.objects.filter(name=jt_name).first()
if old_jt:
logger.info(f'Deleting existing JT {jt_name}')
old_jt.delete()
if proj.current_job: if proj.current_job:
wait_for_job(proj.current_job) wait_for_job(proj.current_job)
@@ -192,13 +166,9 @@ def run_job_from_playbook(demo_inv, post, admin, project_factory):
assert proj.get_project_path() assert proj.get_project_path()
assert playbook in proj.playbooks assert playbook in proj.playbooks
jt_data = {'name': jt_name, 'project': proj.id, 'playbook': playbook, 'inventory': demo_inv.id}
if jt_params:
jt_data.update(jt_params)
result = post( result = post(
reverse('api:job_template_list'), reverse('api:job_template_list'),
jt_data, {'name': jt_name, 'project': proj.id, 'playbook': playbook, 'inventory': demo_inv.id},
admin, admin,
expect=201, expect=201,
) )
@@ -209,6 +179,4 @@ def run_job_from_playbook(demo_inv, post, admin, project_factory):
wait_for_job(job) wait_for_job(job)
assert job.status == 'successful' assert job.status == 'successful'
return {'job': job, 'job_template': jt, 'project': proj}
return _rf return _rf

View File

@@ -1,70 +0,0 @@
import pytest
from awx.main.tests.live.tests.conftest import wait_for_events, wait_for_job
from awx.main.models import Job, Inventory
@pytest.fixture
def facts_project(live_tmp_folder, project_factory):
return project_factory(scm_url=f'file://{live_tmp_folder}/facts')
def assert_facts_populated(name):
job = Job.objects.filter(name__icontains=name).order_by('-created').first()
assert job is not None
wait_for_events(job)
wait_for_job(job)
inventory = job.inventory
assert inventory.hosts.count() > 0 # sanity
for host in inventory.hosts.all():
assert host.ansible_facts
@pytest.fixture
def general_facts_test(facts_project, run_job_from_playbook):
def _rf(slug, jt_params):
jt_params['use_fact_cache'] = True
standard_kwargs = dict(jt_params=jt_params)
# GATHER FACTS
name = f'test_gather_ansible_facts_{slug}'
run_job_from_playbook(name, 'gather.yml', proj=facts_project, **standard_kwargs)
assert_facts_populated(name)
# KEEP FACTS
name = f'test_clear_ansible_facts_{slug}'
run_job_from_playbook(name, 'no_op.yml', proj=facts_project, **standard_kwargs)
assert_facts_populated(name)
# CLEAR FACTS
name = f'test_clear_ansible_facts_{slug}'
run_job_from_playbook(name, 'clear.yml', proj=facts_project, **standard_kwargs)
job = Job.objects.filter(name__icontains=name).order_by('-created').first()
assert job is not None
wait_for_events(job)
inventory = job.inventory
assert inventory.hosts.count() > 0 # sanity
for host in inventory.hosts.all():
assert not host.ansible_facts
return _rf
def test_basic_ansible_facts(general_facts_test):
general_facts_test('basic', {})
@pytest.fixture
def sliced_inventory():
inv, _ = Inventory.objects.get_or_create(name='inventory-to-slice')
if not inv.hosts.exists():
for i in range(10):
inv.hosts.create(name=f'sliced_host_{i}')
return inv
def test_slicing_with_facts(general_facts_test, sliced_inventory):
general_facts_test('sliced', {'job_slice_count': 3, 'inventory': sliced_inventory.id})

View File

@@ -1,78 +0,0 @@
import multiprocessing
import random
from django.db import connection
from django.utils.timezone import now
from awx.main.models import Inventory, Host
from awx.main.utils.db import bulk_update_sorted_by_id
def worker_delete_target(ready_event, continue_event, field_name):
"""Runs the bulk update, will be called in duplicate, in parallel"""
inv = Inventory.objects.get(organization__name='Default', name='test_host_update_contention')
host_list = list(inv.hosts.all())
# Using random.shuffle for non-security-critical shuffling in a test
random.shuffle(host_list) # NOSONAR
for i, host in enumerate(host_list):
setattr(host, field_name, f'my_var: {i}')
# ready to do the bulk_update
print('worker has loaded all the hosts needed')
ready_event.set()
# wait for the coordination message
continue_event.wait()
# NOTE: did not reproduce the bug without batch_size
bulk_update_sorted_by_id(Host, host_list, fields=[field_name], batch_size=100)
print('finished doing the bulk update in worker')
def test_host_update_contention(default_org):
inv_kwargs = dict(organization=default_org, name='test_host_update_contention')
if Inventory.objects.filter(**inv_kwargs).exists():
inv = Inventory.objects.get(**inv_kwargs).delete()
inv = Inventory.objects.create(**inv_kwargs)
right_now = now()
hosts = [Host(inventory=inv, name=f'host-{i}', created=right_now, modified=right_now) for i in range(1000)]
print('bulk creating hosts')
Host.objects.bulk_create(hosts)
# sanity check
for host in hosts:
assert not host.variables
# Force our worker pool to make their own connection
connection.close()
ready_events = [multiprocessing.Event() for _ in range(2)]
continue_event = multiprocessing.Event()
print('spawning processes for concurrent bulk updates')
processes = []
fields = ['variables', 'ansible_facts']
for i in range(2):
p = multiprocessing.Process(target=worker_delete_target, args=(ready_events[i], continue_event, fields[i]))
processes.append(p)
p.start()
# Assure both processes are connected and have loaded their host list
for e in ready_events:
print('waiting on subprocess ready event')
e.wait()
# Begin the bulk_update queries
print('setting the continue event for the workers')
continue_event.set()
# if a Deadloack happens it will probably be surfaced by result here
print('waiting on the workers to finish the bulk_update')
for p in processes:
p.join()
print('checking workers have variables set')
for host in inv.hosts.all():
assert host.variables.startswith('my_var:')
assert host.ansible_facts.startswith('my_var:')

View File

@@ -49,15 +49,22 @@ def test_indirect_host_counting(live_tmp_folder, run_job_from_playbook):
# Task might not run due to race condition, so make it run here # Task might not run due to race condition, so make it run here
job.refresh_from_db() job.refresh_from_db()
if job.event_queries_processed is False: if job.event_queries_processed is False:
save_indirect_host_entries.delay(job.id, wait_for_events=False) for _ in range(10):
save_indirect_host_entries.delay(job.id, wait_for_events=True)
job.refresh_from_db()
if job.event_queries_processed is True:
break
time.sleep(0.5)
else:
raise RuntimeError(f'Job events not received for job_id={job.id}')
# event_queries_processed only assures the task has started, it might take a minor amount of time to finish # This will poll for the background task to finish
for _ in range(10): for _ in range(10):
if IndirectManagedNodeAudit.objects.filter(job=job).exists(): if IndirectManagedNodeAudit.objects.filter(job=job).exists():
break break
time.sleep(0.2) time.sleep(0.2)
else: else:
raise RuntimeError(f'No IndirectManagedNodeAudit records ever populated for job_id={job.id}') raise RuntimeError(f'No IndirectManagedNodeAudit records ever populated for job_id={job.id}')
assert IndirectManagedNodeAudit.objects.filter(job=job).count() == 1 assert IndirectManagedNodeAudit.objects.filter(job=job).count() == 1
host_audit = IndirectManagedNodeAudit.objects.filter(job=job).first() host_audit = IndirectManagedNodeAudit.objects.filter(job=job).first()

View File

@@ -1,224 +0,0 @@
import subprocess
import time
import os.path
from urllib.parse import urlsplit
import pytest
from unittest import mock
from awx.main.models.projects import Project
from awx.main.models.organization import Organization
from awx.main.models.inventory import Inventory, InventorySource
from awx.main.tests.live.tests.conftest import wait_for_job
NAME_PREFIX = "test-ivu"
GIT_REPO_FOLDER = "inventory_vars"
def create_new_by_name(model, **kwargs):
"""
Create a new model instance. Delete an existing instance first.
:param model: The Django model.
:param dict kwargs: The keyword arguments required to create a model
instance. Must contain at least `name`.
:return: The model instance.
"""
name = kwargs["name"]
try:
instance = model.objects.get(name=name)
except model.DoesNotExist:
pass
else:
print(f"FORCE DELETE {name}")
instance.delete()
finally:
instance = model.objects.create(**kwargs)
return instance
def wait_for_update(instance, timeout=3.0):
"""Wait until the last update of *instance* is finished."""
start = time.time()
while time.time() - start < timeout:
if instance.current_job or instance.last_job or instance.last_job_run:
break
time.sleep(0.2)
assert instance.current_job or instance.last_job or instance.last_job_run, f'Instance never updated id={instance.id}'
update = instance.current_job or instance.last_job
if update:
wait_for_job(update)
def change_source_vars_and_update(invsrc, group_vars):
"""
Change the variables content of an inventory source and update its
inventory.
Does not return before the inventory update is finished.
:param invsrc: The inventory source instance.
:param dict group_vars: The variables for various groups. Format::
{
<group>: {<variable>: <value>, <variable>: <value>, ..}, <group>:
{<variable>: <value>, <variable>: <value>, ..}, ..
}
:return: None
"""
project = invsrc.source_project
repo_path = urlsplit(project.scm_url).path
filepath = os.path.join(repo_path, invsrc.source_path)
# print(f"change_source_vars_and_update: {project=} {repo_path=} {filepath=}")
with open(filepath, "w") as fp:
for group, variables in group_vars.items():
fp.write(f"[{group}:vars]\n")
for name, value in variables.items():
fp.write(f"{name}={value}\n")
subprocess.run('git add .; git commit -m "Update variables in invsrc.source_path"', cwd=repo_path, shell=True)
# Update the project to sync the changed repo contents.
project.update()
wait_for_update(project)
# Update the inventory from the changed source.
invsrc.update()
wait_for_update(invsrc)
@pytest.fixture
def organization():
name = f"{NAME_PREFIX}-org"
instance = create_new_by_name(Organization, name=name, description=f"Description for {name}")
yield instance
instance.delete()
@pytest.fixture
def project(organization, live_tmp_folder):
name = f"{NAME_PREFIX}-project"
instance = create_new_by_name(
Project,
name=name,
description=f"Description for {name}",
organization=organization,
scm_url=f"file://{live_tmp_folder}/{GIT_REPO_FOLDER}",
scm_type="git",
)
yield instance
instance.delete()
@pytest.fixture
def inventory(organization):
name = f"{NAME_PREFIX}-inventory"
instance = create_new_by_name(
Inventory,
name=name,
description=f"Description for {name}",
organization=organization,
)
yield instance
instance.delete()
@pytest.fixture
def inventory_source(inventory, project):
name = f"{NAME_PREFIX}-invsrc"
inv_src = InventorySource(
name=name,
source_project=project,
source="scm",
source_path="inventory_var_deleted_in_source.ini",
inventory=inventory,
overwrite_vars=True,
)
with mock.patch('awx.main.models.unified_jobs.UnifiedJobTemplate.update'):
inv_src.save()
yield inv_src
inv_src.delete()
@pytest.fixture
def inventory_source_factory(inventory, project):
"""
Use this fixture if you want to use multiple inventory sources for the same
inventory in your test.
"""
# https://docs.pytest.org/en/stable/how-to/fixtures.html#factories-as-fixtures
created = []
# repo_path = f"{live_tmp_folder}/{GIT_REPO_FOLDER}"
def _factory(inventory_file, name):
# Make sure the inventory file exists before the inventory source
# instance is created.
#
# Note: The current implementation of the inventory source object allows
# to create an instance even when the inventory source file does not
# exist. If this behaviour changes, uncomment the following code block
# and add the fixture `live_tmp_folder` to the factory function
# signature.
#
# inventory_file_path = os.path.join(repo_path, inventory_file) if not
# os.path.isfile(inventory_file_path): with open(inventory_file_path,
# "w") as fp: pass subprocess.run(f'git add .; git commit -m "Create
# {inventory_file_path}"', cwd=repo_path, shell=True)
#
# Create the inventory source instance.
name = f"{NAME_PREFIX}-invsrc-{name}"
inv_src = InventorySource(
name=name,
source_project=project,
source="scm",
source_path=inventory_file,
inventory=inventory,
overwrite_vars=True,
)
with mock.patch('awx.main.models.unified_jobs.UnifiedJobTemplate.update'):
inv_src.save()
return inv_src
yield _factory
for instance in created:
instance.delete()
def test_inventory_var_deleted_in_source(inventory, inventory_source):
"""
Verify that a variable which is deleted from its (git-)source between two
updates is also deleted from the inventory.
Verifies https://issues.redhat.com/browse/AAP-17690
"""
inventory_source.update()
wait_for_update(inventory_source)
assert {"a": "value_a", "b": "value_b"} == Inventory.objects.get(name=inventory.name).variables_dict
# Remove variable `a` from source and verify that it is also removed from
# the inventory variables.
change_source_vars_and_update(inventory_source, {"all": {"b": "value_b"}})
assert {"b": "value_b"} == Inventory.objects.get(name=inventory.name).variables_dict
def test_inventory_vars_with_multiple_sources(inventory, inventory_source_factory):
"""
Verify a sequence of updates from various sources with changing content.
"""
invsrc_a = inventory_source_factory("invsrc_a.ini", "A")
invsrc_b = inventory_source_factory("invsrc_b.ini", "B")
invsrc_c = inventory_source_factory("invsrc_c.ini", "C")
change_source_vars_and_update(invsrc_a, {"all": {"x": "x_from_a", "y": "y_from_a"}})
assert {"x": "x_from_a", "y": "y_from_a"} == Inventory.objects.get(name=inventory.name).variables_dict
change_source_vars_and_update(invsrc_b, {"all": {"x": "x_from_b", "y": "y_from_b", "z": "z_from_b"}})
assert {"x": "x_from_b", "y": "y_from_b", "z": "z_from_b"} == Inventory.objects.get(name=inventory.name).variables_dict
change_source_vars_and_update(invsrc_c, {"all": {"x": "x_from_c", "z": "z_from_c"}})
assert {"x": "x_from_c", "y": "y_from_b", "z": "z_from_c"} == Inventory.objects.get(name=inventory.name).variables_dict
change_source_vars_and_update(invsrc_b, {"all": {}})
assert {"x": "x_from_c", "y": "y_from_a", "z": "z_from_c"} == Inventory.objects.get(name=inventory.name).variables_dict
change_source_vars_and_update(invsrc_c, {"all": {"z": "z_from_c"}})
assert {"x": "x_from_a", "y": "y_from_a", "z": "z_from_c"} == Inventory.objects.get(name=inventory.name).variables_dict
change_source_vars_and_update(invsrc_a, {"all": {}})
assert {"z": "z_from_c"} == Inventory.objects.get(name=inventory.name).variables_dict
change_source_vars_and_update(invsrc_c, {"all": {}})
assert {} == Inventory.objects.get(name=inventory.name).variables_dict

View File

@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import json import json
import os import os
import time
import pytest import pytest
from awx.main.models import ( from awx.main.models import (
@@ -13,8 +15,6 @@ from django.utils.timezone import now
from datetime import timedelta from datetime import timedelta
import time
@pytest.fixture @pytest.fixture
def ref_time(): def ref_time():
@@ -33,23 +33,15 @@ def hosts(ref_time):
def test_start_job_fact_cache(hosts, tmpdir): def test_start_job_fact_cache(hosts, tmpdir):
# Create artifacts dir inside tmpdir fact_cache = os.path.join(tmpdir, 'facts')
artifacts_dir = tmpdir.mkdir("artifacts") last_modified = start_fact_cache(hosts, fact_cache, timeout=0)
# Assign a mock inventory ID
inventory_id = 42
# Call the function WITHOUT log_data — the decorator handles it
start_fact_cache(hosts, artifacts_dir=str(artifacts_dir), timeout=0, inventory_id=inventory_id)
# Fact files are written into artifacts_dir/fact_cache/
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
for host in hosts: for host in hosts:
filepath = os.path.join(fact_cache_dir, host.name) filepath = os.path.join(fact_cache, host.name)
assert os.path.exists(filepath) assert os.path.exists(filepath)
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, 'r') as f:
assert json.load(f) == host.ansible_facts assert f.read() == json.dumps(host.ansible_facts)
assert os.path.getmtime(filepath) <= last_modified
def test_fact_cache_with_invalid_path_traversal(tmpdir): def test_fact_cache_with_invalid_path_traversal(tmpdir):
@@ -59,84 +51,64 @@ def test_fact_cache_with_invalid_path_traversal(tmpdir):
ansible_facts={"a": 1, "b": 2}, ansible_facts={"a": 1, "b": 2},
), ),
] ]
artifacts_dir = tmpdir.mkdir("artifacts")
inventory_id = 42
start_fact_cache(hosts, artifacts_dir=str(artifacts_dir), timeout=0, inventory_id=inventory_id) fact_cache = os.path.join(tmpdir, 'facts')
start_fact_cache(hosts, fact_cache, timeout=0)
# Fact cache directory (safe location) # a file called "foo" should _not_ be written outside the facts dir
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache') assert os.listdir(os.path.join(fact_cache, '..')) == ['facts']
# The bad host name should not produce a file
assert not os.path.exists(os.path.join(fact_cache_dir, '../foo'))
# Make sure the fact_cache dir exists and is still empty
assert os.listdir(fact_cache_dir) == []
def test_start_job_fact_cache_past_timeout(hosts, tmpdir): def test_start_job_fact_cache_past_timeout(hosts, tmpdir):
fact_cache = os.path.join(tmpdir, 'facts') fact_cache = os.path.join(tmpdir, 'facts')
start_fact_cache(hosts, fact_cache, timeout=2) # the hosts fixture was modified 5s ago, which is more than 2s
last_modified = start_fact_cache(hosts, fact_cache, timeout=2)
assert last_modified is None
for host in hosts: for host in hosts:
assert not os.path.exists(os.path.join(fact_cache, host.name)) assert not os.path.exists(os.path.join(fact_cache, host.name))
ret = start_fact_cache(hosts, fact_cache, timeout=2)
assert ret is None
def test_start_job_fact_cache_within_timeout(hosts, tmpdir): def test_start_job_fact_cache_within_timeout(hosts, tmpdir):
artifacts_dir = tmpdir.mkdir("artifacts")
# The hosts fixture was modified 5s ago, which is less than 7s
start_fact_cache(hosts, str(artifacts_dir), timeout=7)
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
for host in hosts:
filepath = os.path.join(fact_cache_dir, host.name)
assert os.path.exists(filepath)
with open(filepath, 'r') as f:
assert json.load(f) == host.ansible_facts
def test_finish_job_fact_cache_clear(hosts, mocker, ref_time, tmpdir):
fact_cache = os.path.join(tmpdir, 'facts') fact_cache = os.path.join(tmpdir, 'facts')
start_fact_cache(hosts, fact_cache, timeout=0) # the hosts fixture was modified 5s ago, which is less than 7s
last_modified = start_fact_cache(hosts, fact_cache, timeout=7)
assert last_modified
bulk_update = mocker.patch('awx.main.tasks.facts.bulk_update_sorted_by_id') for host in hosts:
assert os.path.exists(os.path.join(fact_cache, host.name))
# Mock the os.path.exists behavior for host deletion
# Let's assume the fact file for hosts[1] is missing.
mocker.patch('os.path.exists', side_effect=lambda path: hosts[1].name not in path)
# Simulate one host's fact file getting deleted manually def test_finish_job_fact_cache_with_existing_data(hosts, mocker, tmpdir, ref_time):
host_to_delete_filepath = os.path.join(fact_cache, hosts[1].name) fact_cache = os.path.join(tmpdir, 'facts')
last_modified = start_fact_cache(hosts, fact_cache, timeout=0)
# Simulate the file being removed by checking existence first, to avoid FileNotFoundError bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update')
if os.path.exists(host_to_delete_filepath):
os.remove(host_to_delete_filepath)
finish_fact_cache(fact_cache) ansible_facts_new = {"foo": "bar"}
filepath = os.path.join(fact_cache, hosts[1].name)
with open(filepath, 'w') as f:
f.write(json.dumps(ansible_facts_new))
f.flush()
# I feel kind of gross about calling `os.utime` by hand, but I noticed
# that in our container-based dev environment, the resolution for
# `os.stat()` after a file write was over a second, and I don't want to put
# a sleep() in this test
new_modification_time = time.time() + 3600
os.utime(filepath, (new_modification_time, new_modification_time))
# Simulate side effects that would normally be applied during bulk update finish_fact_cache(hosts, fact_cache, last_modified)
hosts[1].ansible_facts = {}
hosts[1].ansible_facts_modified = now()
# Verify facts are preserved for hosts with valid cache files
for host in (hosts[0], hosts[2], hosts[3]): for host in (hosts[0], hosts[2], hosts[3]):
assert host.ansible_facts == {"a": 1, "b": 2} assert host.ansible_facts == {"a": 1, "b": 2}
assert host.ansible_facts_modified == ref_time assert host.ansible_facts_modified == ref_time
assert hosts[1].ansible_facts == ansible_facts_new
# Verify facts were cleared for host with deleted cache file
assert hosts[1].ansible_facts == {}
assert hosts[1].ansible_facts_modified > ref_time assert hosts[1].ansible_facts_modified > ref_time
bulk_update.assert_called_once_with([hosts[1]], ['ansible_facts', 'ansible_facts_modified'])
# Current implementation skips the call entirely if hosts_to_update == []
bulk_update.assert_not_called()
def test_finish_job_fact_cache_with_bad_data(hosts, mocker, tmpdir): def test_finish_job_fact_cache_with_bad_data(hosts, mocker, tmpdir):
fact_cache = os.path.join(tmpdir, 'facts') fact_cache = os.path.join(tmpdir, 'facts')
start_fact_cache(hosts, fact_cache, timeout=0) last_modified = start_fact_cache(hosts, fact_cache, timeout=0)
bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update') bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update')
@@ -148,6 +120,23 @@ def test_finish_job_fact_cache_with_bad_data(hosts, mocker, tmpdir):
new_modification_time = time.time() + 3600 new_modification_time = time.time() + 3600
os.utime(filepath, (new_modification_time, new_modification_time)) os.utime(filepath, (new_modification_time, new_modification_time))
finish_fact_cache(fact_cache) finish_fact_cache(hosts, fact_cache, last_modified)
bulk_update.assert_not_called() bulk_update.assert_not_called()
def test_finish_job_fact_cache_clear(hosts, mocker, ref_time, tmpdir):
fact_cache = os.path.join(tmpdir, 'facts')
last_modified = start_fact_cache(hosts, fact_cache, timeout=0)
bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update')
os.remove(os.path.join(fact_cache, hosts[1].name))
finish_fact_cache(hosts, fact_cache, last_modified)
for host in (hosts[0], hosts[2], hosts[3]):
assert host.ansible_facts == {"a": 1, "b": 2}
assert host.ansible_facts_modified == ref_time
assert hosts[1].ansible_facts == {}
assert hosts[1].ansible_facts_modified > ref_time
bulk_update.assert_called_once_with([hosts[1]], ['ansible_facts', 'ansible_facts_modified'])

View File

@@ -561,7 +561,7 @@ class TestBFSNodesToRun:
assert set([nodes[1], nodes[2]]) == set(g.bfs_nodes_to_run()) assert set([nodes[1], nodes[2]]) == set(g.bfs_nodes_to_run())
@pytest.mark.xfail(reason="Run manually to re-generate doc images") @pytest.mark.skip(reason="Run manually to re-generate doc images")
class TestDocsExample: class TestDocsExample:
@pytest.fixture @pytest.fixture
def complex_dag(self, wf_node_generator): def complex_dag(self, wf_node_generator):

View File

@@ -1,190 +0,0 @@
# -*- coding: utf-8 -*-
import os
import tempfile
import shutil
import pytest
from unittest import mock
from awx.main.models import (
Inventory,
Host,
)
from django.utils.timezone import now
from django.db.models.query import QuerySet
from awx.main.models import (
Job,
Organization,
Project,
)
from awx.main.tasks import jobs
@pytest.fixture
def private_data_dir():
private_data = tempfile.mkdtemp(prefix='awx_')
for subfolder in ('inventory', 'env'):
runner_subfolder = os.path.join(private_data, subfolder)
os.makedirs(runner_subfolder, exist_ok=True)
yield private_data
shutil.rmtree(private_data, True)
@mock.patch('awx.main.tasks.facts.settings')
@mock.patch('awx.main.tasks.jobs.create_partition', return_value=True)
def test_pre_post_run_hook_facts(mock_create_partition, mock_facts_settings, private_data_dir, execution_environment):
# Create mocked inventory and host queryset
inventory = mock.MagicMock(spec=Inventory, pk=1)
host1 = mock.MagicMock(spec=Host, id=1, name='host1', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=inventory)
host2 = mock.MagicMock(spec=Host, id=2, name='host2', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=inventory)
# Mock hosts queryset
hosts = [host1, host2]
qs_hosts = mock.MagicMock(spec=QuerySet)
qs_hosts._result_cache = hosts
qs_hosts.only.return_value = hosts
qs_hosts.count.side_effect = lambda: len(qs_hosts._result_cache)
inventory.hosts = qs_hosts
# Create mocked job object
org = mock.MagicMock(spec=Organization, pk=1)
proj = mock.MagicMock(spec=Project, pk=1, organization=org)
job = mock.MagicMock(
spec=Job,
use_fact_cache=True,
project=proj,
organization=org,
job_slice_number=1,
job_slice_count=1,
inventory=inventory,
execution_environment=execution_environment,
)
job.get_hosts_for_fact_cache = Job.get_hosts_for_fact_cache.__get__(job)
job.job_env.get = mock.MagicMock(return_value=private_data_dir)
# Mock RunJob task
mock_facts_settings.ANSIBLE_FACT_CACHE_TIMEOUT = False
task = jobs.RunJob()
task.instance = job
task.update_model = mock.Mock(return_value=job)
task.model.objects.get = mock.Mock(return_value=job)
# Run pre_run_hook
task.facts_write_time = task.pre_run_hook(job, private_data_dir)
# Add a third mocked host
host3 = mock.MagicMock(spec=Host, id=3, name='host3', ansible_facts={"added": True}, ansible_facts_modified=now(), inventory=inventory)
qs_hosts._result_cache.append(host3)
assert inventory.hosts.count() == 3
# Run post_run_hook
task.runner_callback.artifacts_processed = mock.MagicMock(return_value=True)
task.post_run_hook(job, "success")
# Verify final host facts
assert qs_hosts._result_cache[2].ansible_facts == {"added": True}
@mock.patch('awx.main.tasks.facts.bulk_update_sorted_by_id')
@mock.patch('awx.main.tasks.facts.settings')
@mock.patch('awx.main.tasks.jobs.create_partition', return_value=True)
def test_pre_post_run_hook_facts_deleted_sliced(mock_create_partition, mock_facts_settings, private_data_dir, execution_environment):
# Fully mocked inventory
mock_inventory = mock.MagicMock(spec=Inventory)
# Create 999 mocked Host instances
hosts = []
for i in range(999):
host = mock.MagicMock(spec=Host)
host.id = i
host.name = f'host{i}'
host.ansible_facts = {"a": 1, "b": 2}
host.ansible_facts_modified = now()
host.inventory = mock_inventory
hosts.append(host)
# Mock inventory.hosts behavior
mock_qs_hosts = mock.MagicMock()
mock_qs_hosts.only.return_value = hosts
mock_qs_hosts.count.return_value = 999
mock_inventory.hosts = mock_qs_hosts
# Mock Organization and Project
org = mock.MagicMock(spec=Organization)
proj = mock.MagicMock(spec=Project)
proj.organization = org
# Mock job object
job = mock.MagicMock(spec=Job)
job.use_fact_cache = True
job.project = proj
job.organization = org
job.job_slice_number = 1
job.job_slice_count = 3
job.execution_environment = execution_environment
job.inventory = mock_inventory
job.job_env.get.return_value = private_data_dir
# Bind actual method for host filtering
job.get_hosts_for_fact_cache = Job.get_hosts_for_fact_cache.__get__(job)
# Mock task instance
mock_facts_settings.ANSIBLE_FACT_CACHE_TIMEOUT = False
task = jobs.RunJob()
task.instance = job
task.update_model = mock.Mock(return_value=job)
task.model.objects.get = mock.Mock(return_value=job)
# Call pre_run_hook
task.facts_write_time = task.pre_run_hook(job, private_data_dir)
# Simulate one host deletion
hosts.pop(1)
mock_qs_hosts.count.return_value = 998
# Call post_run_hook
task.runner_callback.artifacts_processed = mock.MagicMock(return_value=True)
task.post_run_hook(job, "success")
# Assert that ansible_facts were preserved
for host in hosts:
assert host.ansible_facts == {"a": 1, "b": 2}
# Add expected failure cases
failures = []
for host in hosts:
try:
assert host.ansible_facts == {"a": 1, "b": 2, "unexpected_key": "bad"}
except AssertionError:
failures.append(f"Host named {host.name} has facts {host.ansible_facts}")
assert len(failures) > 0, f"Failures occurred for the following hosts: {failures}"
@mock.patch('awx.main.tasks.facts.bulk_update_sorted_by_id')
@mock.patch('awx.main.tasks.facts.settings')
def test_invalid_host_facts(mock_facts_settings, bulk_update_sorted_by_id, private_data_dir, execution_environment):
inventory = Inventory(pk=1)
mock_inventory = mock.MagicMock(spec=Inventory, wraps=inventory)
mock_inventory._state = mock.MagicMock()
hosts = [
Host(id=0, name='host0', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=mock_inventory),
Host(id=1, name='host1', ansible_facts={"a": 1, "b": 2, "unexpected_key": "bad"}, ansible_facts_modified=now(), inventory=mock_inventory),
]
mock_inventory.hosts = hosts
failures = []
for host in mock_inventory.hosts:
assert "a" in host.ansible_facts
if "unexpected_key" in host.ansible_facts:
failures.append(host.name)
mock_facts_settings.SOME_SETTING = True
bulk_update_sorted_by_id(Host, mock_inventory.hosts, fields=['ansible_facts'])
with pytest.raises(pytest.fail.Exception):
if failures:
pytest.fail(f" {len(failures)} facts cleared failures : {','.join(failures)}")

Some files were not shown because too many files have changed in this diff Show More