Compare commits

..

2 Commits

Author SHA1 Message Date
Peter Braun
543d3f940b update licenses and embedded sources 2025-04-15 11:14:10 +02:00
Peter Braun
ee7edb9179 update sqlparse dependency 2025-04-14 23:21:16 +02:00
137 changed files with 1070 additions and 5423 deletions

View File

@@ -2,7 +2,7 @@
codecov:
notify:
after_n_builds: 9 # Number of test matrix+lint jobs uploading coverage
after_n_builds: 6 # Number of test matrix+lint jobs uploading coverage
wait_for_ci: false
require_ci_to_pass: false

View File

@@ -17,23 +17,6 @@ exclude_also =
[run]
branch = True
# NOTE: `disable_warnings` is needed when `pytest-cov` runs in tandem
# NOTE: with `pytest-xdist`. These warnings are false positives in this
# NOTE: context.
#
# NOTE: It's `coveragepy` that emits the warnings and previously they
# NOTE: wouldn't get on the radar of `pytest`'s `filterwarnings`
# NOTE: mechanism. This changed, however, with `pytest >= 8.4`. And
# NOTE: since we set `filterwarnings = error`, those warnings are being
# NOTE: raised as exceptions, cascading into `pytest`'s internals and
# NOTE: causing tracebacks and crashes of the test sessions.
#
# Ref:
# * https://github.com/pytest-dev/pytest-cov/issues/693
# * https://github.com/pytest-dev/pytest-cov/pull/695
# * https://github.com/pytest-dev/pytest-cov/pull/696
disable_warnings =
module-not-measured
omit =
awx/main/migrations/*
awx/settings/defaults.py

View File

@@ -4,8 +4,7 @@
<!---
If you are fixing an existing issue, please include "related #nnn" in your
commit message and your description; but you should still explain what
the change does. Also please make sure that if this PR has an attached JIRA, put AAP-<number>
in as the first entry for your PR title.
the change does.
-->
##### ISSUE TYPE
@@ -23,6 +22,11 @@ in as the first entry for your PR title.
- Docs
- Other
##### AWX VERSION
<!--- Paste verbatim output from `make VERSION` between quotes below -->
```
```
##### ADDITIONAL INFORMATION

View File

@@ -335,7 +335,6 @@ jobs:
with:
name: coverage-${{ matrix.target-regex.name }}
path: ~/.ansible/collections/ansible_collections/awx/awx/tests/output/coverage/
retention-days: 1
- uses: ./.github/actions/upload_awx_devel_logs
if: always()
@@ -353,7 +352,6 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
show-progress: false
- uses: ./.github/actions/setup-python
@@ -363,12 +361,23 @@ jobs:
- name: Upgrade ansible-core
run: python3 -m pip install --upgrade ansible-core
- name: Download coverage artifacts
- name: Download coverage artifacts A to H
uses: actions/download-artifact@v4
with:
merge-multiple: true
name: coverage-a-h
path: coverage
- name: Download coverage artifacts I to P
uses: actions/download-artifact@v4
with:
name: coverage-i-p
path: coverage
- name: Download coverage artifacts Z to Z
uses: actions/download-artifact@v4
with:
name: coverage-r-z0-9
path: coverage
pattern: coverage-*
- name: Combine coverage
run: |
@@ -386,6 +395,46 @@ jobs:
echo '## AWX Collection Integration Coverage HTML' >> $GITHUB_STEP_SUMMARY
echo 'Download the HTML artifacts to view the coverage report.' >> $GITHUB_STEP_SUMMARY
# This is a huge hack, there's no official action for removing artifacts currently.
# Also ACTIONS_RUNTIME_URL and ACTIONS_RUNTIME_TOKEN aren't available in normal run
# steps, so we have to use github-script to get them.
#
# The advantage of doing this, though, is that we save on artifact storage space.
- name: Get secret artifact runtime URL
uses: actions/github-script@v6
id: get-runtime-url
with:
result-encoding: string
script: |
const { ACTIONS_RUNTIME_URL } = process.env;
return ACTIONS_RUNTIME_URL;
- name: Get secret artifact runtime token
uses: actions/github-script@v6
id: get-runtime-token
with:
result-encoding: string
script: |
const { ACTIONS_RUNTIME_TOKEN } = process.env;
return ACTIONS_RUNTIME_TOKEN;
- name: Remove intermediary artifacts
env:
ACTIONS_RUNTIME_URL: ${{ steps.get-runtime-url.outputs.result }}
ACTIONS_RUNTIME_TOKEN: ${{ steps.get-runtime-token.outputs.result }}
run: |
echo "::add-mask::${ACTIONS_RUNTIME_TOKEN}"
artifacts=$(
curl -H "Authorization: Bearer $ACTIONS_RUNTIME_TOKEN" \
${ACTIONS_RUNTIME_URL}_apis/pipelines/workflows/${{ github.run_id }}/artifacts?api-version=6.0-preview \
| jq -r '.value | .[] | select(.name | startswith("coverage-")) | .url'
)
for artifact in $artifacts; do
curl -i -X DELETE -H "Accept: application/json;api-version=6.0-preview" -H "Authorization: Bearer $ACTIONS_RUNTIME_TOKEN" "$artifact"
done
- name: Upload coverage report as artifact
uses: actions/upload-artifact@v4
with:

2
.gitignore vendored
View File

@@ -150,8 +150,6 @@ use_dev_supervisor.txt
awx/ui/src
awx/ui/build
awx/ui/.ui-built
awx/ui_next
# Docs build stuff
docs/docsite/build/

View File

@@ -19,12 +19,6 @@ COLLECTION_VERSION ?= $(shell $(PYTHON) tools/scripts/scm_version.py | cut -d .
COLLECTION_SANITY_ARGS ?= --docker
# collection unit testing directories
COLLECTION_TEST_DIRS ?= awx_collection/test/awx
# pytest added args to collect coverage
COVERAGE_ARGS ?= --cov --cov-report=xml --junitxml=reports/junit.xml
# pytest test directories
TEST_DIRS ?= awx/main/tests/unit awx/main/tests/functional awx/conf/tests
# pytest args to run tests in parallel
PARALLEL_TESTS ?= -n auto
# collection integration test directories (defaults to all)
COLLECTION_TEST_TARGET ?=
# args for collection install
@@ -315,14 +309,14 @@ black: reports
@chmod +x .git/hooks/pre-commit
genschema: reports
$(MAKE) swagger PYTEST_ADDOPTS="--genschema --create-db "
$(MAKE) swagger PYTEST_ARGS="--genschema --create-db "
mv swagger.json schema.json
swagger: reports
@if [ "$(VENV_BASE)" ]; then \
. $(VENV_BASE)/awx/bin/activate; \
fi; \
(set -o pipefail && py.test $(COVERAGE_ARGS) $(PARALLEL_TESTS) awx/conf/tests/functional awx/main/tests/functional/api awx/main/tests/docs | tee reports/$@.report)
(set -o pipefail && py.test --cov --cov-report=xml --junitxml=reports/junit.xml $(PYTEST_ARGS) awx/conf/tests/functional awx/main/tests/functional/api awx/main/tests/docs | tee reports/$@.report)
@if [ "${GITHUB_ACTIONS}" = "true" ]; \
then \
echo 'cov-report-files=reports/coverage.xml' >> "${GITHUB_OUTPUT}"; \
@@ -340,12 +334,14 @@ api-lint:
awx-link:
[ -d "/awx_devel/awx.egg-info" ] || $(PYTHON) /awx_devel/tools/scripts/egg_info_dev
TEST_DIRS ?= awx/main/tests/unit awx/main/tests/functional awx/conf/tests
PYTEST_ARGS ?= -n auto
## Run all API unit tests.
test:
if [ "$(VENV_BASE)" ]; then \
. $(VENV_BASE)/awx/bin/activate; \
fi; \
PYTHONDONTWRITEBYTECODE=1 py.test -p no:cacheprovider $(PARALLEL_TESTS) $(TEST_DIRS)
PYTHONDONTWRITEBYTECODE=1 py.test -p no:cacheprovider $(PYTEST_ARGS) $(TEST_DIRS)
cd awxkit && $(VENV_BASE)/awx/bin/tox -re py3
awx-manage check_migrations --dry-run --check -n 'missing_migration_file'
@@ -354,7 +350,7 @@ live_test:
## Run all API unit tests with coverage enabled.
test_coverage:
$(MAKE) test PYTEST_ADDOPTS="--create-db $(COVERAGE_ARGS)"
$(MAKE) test PYTEST_ARGS="--create-db --cov --cov-report=xml --junitxml=reports/junit.xml"
@if [ "${GITHUB_ACTIONS}" = "true" ]; \
then \
echo 'cov-report-files=awxkit/coverage.xml,reports/coverage.xml' >> "${GITHUB_OUTPUT}"; \
@@ -362,7 +358,7 @@ test_coverage:
fi
test_migrations:
PYTHONDONTWRITEBYTECODE=1 py.test -p no:cacheprovider --migrations -m migration_test --create-db $(PARALLEL_TESTS) $(COVERAGE_ARGS) $(TEST_DIRS)
PYTHONDONTWRITEBYTECODE=1 py.test -p no:cacheprovider --migrations -m migration_test --create-db --cov=awx --cov-report=xml --junitxml=reports/junit.xml $(PYTEST_ARGS) $(TEST_DIRS)
@if [ "${GITHUB_ACTIONS}" = "true" ]; \
then \
echo 'cov-report-files=reports/coverage.xml' >> "${GITHUB_OUTPUT}"; \
@@ -380,7 +376,7 @@ test_collection:
fi && \
if ! [ -x "$(shell command -v ansible-playbook)" ]; then pip install ansible-core; fi
ansible --version
py.test $(COLLECTION_TEST_DIRS) $(COVERAGE_ARGS) -v
py.test $(COLLECTION_TEST_DIRS) --cov --cov-report=xml --junitxml=reports/junit.xml -v
@if [ "${GITHUB_ACTIONS}" = "true" ]; \
then \
echo 'cov-report-files=reports/coverage.xml' >> "${GITHUB_OUTPUT}"; \

View File

@@ -7,7 +7,6 @@ import json
import logging
import re
import yaml
import urllib.parse
from collections import Counter, OrderedDict
from datetime import timedelta
from uuid import uuid4
@@ -117,7 +116,6 @@ from awx.main.utils import (
from awx.main.utils.filters import SmartFilter
from awx.main.utils.plugins import load_combined_inventory_source_options
from awx.main.utils.named_url_graph import reset_counters
from awx.main.utils.inventory_vars import update_group_variables
from awx.main.scheduler.task_manager_models import TaskManagerModels
from awx.main.redact import UriCleaner, REPLACE_STR
from awx.main.signals import update_inventory_computed_fields
@@ -734,22 +732,7 @@ class EmptySerializer(serializers.Serializer):
pass
class OpaQueryPathMixin(serializers.Serializer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def validate_opa_query_path(self, value):
# Decode the URL and re-encode it
decoded_value = urllib.parse.unquote(value)
re_encoded_value = urllib.parse.quote(decoded_value, safe='/')
if value != re_encoded_value:
raise serializers.ValidationError(_("The URL must be properly encoded."))
return value
class UnifiedJobTemplateSerializer(BaseSerializer, OpaQueryPathMixin):
class UnifiedJobTemplateSerializer(BaseSerializer):
# As a base serializer, the capabilities prefetch is not used directly,
# instead they are derived from the Workflow Job Template Serializer and the Job Template Serializer, respectively.
capabilities_prefetch = []
@@ -1182,12 +1165,12 @@ class UserActivityStreamSerializer(UserSerializer):
fields = ('*', '-is_system_auditor')
class OrganizationSerializer(BaseSerializer, OpaQueryPathMixin):
class OrganizationSerializer(BaseSerializer):
show_capabilities = ['edit', 'delete']
class Meta:
model = Organization
fields = ('*', 'max_hosts', 'custom_virtualenv', 'default_environment', 'opa_query_path')
fields = ('*', 'max_hosts', 'custom_virtualenv', 'default_environment')
read_only_fields = ('*', 'custom_virtualenv')
def get_related(self, obj):
@@ -1541,7 +1524,7 @@ class LabelsListMixin(object):
return res
class InventorySerializer(LabelsListMixin, BaseSerializerWithVariables, OpaQueryPathMixin):
class InventorySerializer(LabelsListMixin, BaseSerializerWithVariables):
show_capabilities = ['edit', 'delete', 'adhoc', 'copy']
capabilities_prefetch = ['admin', 'adhoc', {'copy': 'organization.inventory_admin'}]
@@ -1562,7 +1545,6 @@ class InventorySerializer(LabelsListMixin, BaseSerializerWithVariables, OpaQuery
'inventory_sources_with_failures',
'pending_deletion',
'prevent_instance_group_fallback',
'opa_query_path',
)
def get_related(self, obj):
@@ -1632,68 +1614,8 @@ class InventorySerializer(LabelsListMixin, BaseSerializerWithVariables, OpaQuery
if kind == 'smart' and not host_filter:
raise serializers.ValidationError({'host_filter': _('Smart inventories must specify host_filter')})
return super(InventorySerializer, self).validate(attrs)
@staticmethod
def _update_variables(variables, inventory_id):
"""
Update the inventory variables of the 'all'-group.
The variables field contains vars from the inventory dialog, hence
representing the "all"-group variables.
Since this is not an update from an inventory source, we update the
variables when the inventory details form is saved.
A user edit on the inventory variables is considered a reset of the
variables update history. Particularly if the user removes a variable by
editing the inventory variables field, the variable is not supposed to
reappear with a value from a previous inventory source update.
We achieve this by forcing `reset=True` on such an update.
As a side-effect, variables which have been set by source updates and
have survived a user-edit (i.e. they have not been deleted from the
variables field) will be assumed to originate from the user edit and are
thus no longer deleted from the inventory when they are removed from
their original source!
Note that we use the inventory source id -1 for user-edit updates
because a regular inventory source cannot have an id of -1 since
PostgreSQL assigns pk's starting from 1 (if this assumption doesn't hold
true, we have to assign another special value for invsrc_id).
:param str variables: The variables as plain text in yaml or json
format.
:param int inventory_id: The primary key of the related inventory
object.
"""
variables_dict = parse_yaml_or_json(variables, silent_failure=False)
logger.debug(f"InventorySerializer._update_variables: {inventory_id=} {variables_dict=}, {variables=}")
update_group_variables(
group_id=None, # `None` denotes the 'all' group (which doesn't have a pk).
newvars=variables_dict,
dbvars=None,
invsrc_id=-1,
inventory_id=inventory_id,
reset=True,
)
def create(self, validated_data):
"""Called when a new inventory has to be created."""
logger.debug(f"InventorySerializer.create({validated_data=}) >>>>")
obj = super().create(validated_data)
self._update_variables(validated_data.get("variables") or "", obj.id)
return obj
def update(self, obj, validated_data):
"""Called when an existing inventory is updated."""
logger.debug(f"InventorySerializer.update({validated_data=}) >>>>")
obj = super().update(obj, validated_data)
self._update_variables(validated_data.get("variables") or "", obj.id)
return obj
class ConstructedFieldMixin(serializers.Field):
def get_attribute(self, instance):
@@ -1983,12 +1905,10 @@ class GroupSerializer(BaseSerializerWithVariables):
return res
def validate(self, attrs):
# Do not allow the group name to conflict with an existing host name.
name = force_str(attrs.get('name', self.instance and self.instance.name or ''))
inventory = attrs.get('inventory', self.instance and self.instance.inventory or '')
if Host.objects.filter(name=name, inventory=inventory).exists():
raise serializers.ValidationError(_('A Host with that name already exists.'))
#
return super(GroupSerializer, self).validate(attrs)
def validate_name(self, value):
@@ -3327,7 +3247,6 @@ class JobTemplateSerializer(JobTemplateMixin, UnifiedJobTemplateSerializer, JobO
'webhook_service',
'webhook_credential',
'prevent_instance_group_fallback',
'opa_query_path',
)
read_only_fields = ('*', 'custom_virtualenv')

View File

@@ -10,7 +10,7 @@ from awx.api.generics import APIView, Response
from awx.api.permissions import AnalyticsPermission
from awx.api.versioning import reverse
from awx.main.utils import get_awx_version
from awx.main.utils.analytics_proxy import OIDCClient
from awx.main.utils.analytics_proxy import OIDCClient, DEFAULT_OIDC_TOKEN_ENDPOINT
from rest_framework import status
from collections import OrderedDict
@@ -202,16 +202,10 @@ class AnalyticsGenericView(APIView):
if method not in ["GET", "POST", "OPTIONS"]:
return self._error_response(ERROR_UNSUPPORTED_METHOD, method, remote=False, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
url = self._get_analytics_url(request.path)
using_subscriptions_credentials = False
try:
rh_user = getattr(settings, 'REDHAT_USERNAME', None)
rh_password = getattr(settings, 'REDHAT_PASSWORD', None)
if not (rh_user and rh_password):
rh_user = self._get_setting('SUBSCRIPTIONS_CLIENT_ID', None, ERROR_MISSING_USER)
rh_password = self._get_setting('SUBSCRIPTIONS_CLIENT_SECRET', None, ERROR_MISSING_PASSWORD)
using_subscriptions_credentials = True
client = OIDCClient(rh_user, rh_password)
rh_user = self._get_setting('REDHAT_USERNAME', None, ERROR_MISSING_USER)
rh_password = self._get_setting('REDHAT_PASSWORD', None, ERROR_MISSING_PASSWORD)
client = OIDCClient(rh_user, rh_password, DEFAULT_OIDC_TOKEN_ENDPOINT, ['api.console'])
response = client.make_request(
method,
url,
@@ -222,17 +216,17 @@ class AnalyticsGenericView(APIView):
timeout=(31, 31),
)
except requests.RequestException:
# subscriptions credentials are not valid for basic auth, so just return 401
if using_subscriptions_credentials:
response = Response(status=status.HTTP_401_UNAUTHORIZED)
else:
logger.error("Automation Analytics API request failed, trying base auth method")
response = self._base_auth_request(request, method, url, rh_user, rh_password, headers)
logger.error("Automation Analytics API request failed, trying base auth method")
response = self._base_auth_request(request, method, url, rh_user, rh_password, headers)
except MissingSettings:
rh_user = self._get_setting('SUBSCRIPTIONS_USERNAME', None, ERROR_MISSING_USER)
rh_password = self._get_setting('SUBSCRIPTIONS_PASSWORD', None, ERROR_MISSING_PASSWORD)
response = self._base_auth_request(request, method, url, rh_user, rh_password, headers)
#
# Missing or wrong user/pass
#
if response.status_code == status.HTTP_401_UNAUTHORIZED:
text = response.get('text', '').rstrip("\n")
text = (response.text or '').rstrip("\n")
return self._error_response(ERROR_UNAUTHORIZED, text, remote=True, remote_status_code=response.status_code)
#
# Not found, No entitlement or No data in Analytics

View File

@@ -32,7 +32,6 @@ from awx.api.versioning import URLPathVersioning, reverse, drf_reverse
from awx.main.constants import PRIVILEGE_ESCALATION_METHODS
from awx.main.models import Project, Organization, Instance, InstanceGroup, JobTemplate
from awx.main.utils import set_environ
from awx.main.utils.analytics_proxy import TokenError
from awx.main.utils.licensing import get_licenser
logger = logging.getLogger('awx.api.views.root')
@@ -177,21 +176,19 @@ class ApiV2SubscriptionView(APIView):
def post(self, request):
data = request.data.copy()
if data.get('subscriptions_client_secret') == '$encrypted$':
data['subscriptions_client_secret'] = settings.SUBSCRIPTIONS_CLIENT_SECRET
if data.get('subscriptions_password') == '$encrypted$':
data['subscriptions_password'] = settings.SUBSCRIPTIONS_PASSWORD
try:
user, pw = data.get('subscriptions_client_id'), data.get('subscriptions_client_secret')
user, pw = data.get('subscriptions_username'), data.get('subscriptions_password')
with set_environ(**settings.AWX_TASK_ENV):
validated = get_licenser().validate_rh(user, pw)
if user:
settings.SUBSCRIPTIONS_CLIENT_ID = data['subscriptions_client_id']
settings.SUBSCRIPTIONS_USERNAME = data['subscriptions_username']
if pw:
settings.SUBSCRIPTIONS_CLIENT_SECRET = data['subscriptions_client_secret']
settings.SUBSCRIPTIONS_PASSWORD = data['subscriptions_password']
except Exception as exc:
msg = _("Invalid Subscription")
if isinstance(exc, TokenError) or (
isinstance(exc, requests.exceptions.HTTPError) and getattr(getattr(exc, 'response', None), 'status_code', None) == 401
):
if isinstance(exc, requests.exceptions.HTTPError) and getattr(getattr(exc, 'response', None), 'status_code', None) == 401:
msg = _("The provided credentials are invalid (HTTP 401).")
elif isinstance(exc, requests.exceptions.ProxyError):
msg = _("Unable to connect to proxy server.")
@@ -218,12 +215,12 @@ class ApiV2AttachView(APIView):
def post(self, request):
data = request.data.copy()
subscription_id = data.get('subscription_id', None)
if not subscription_id:
return Response({"error": _("No subscription ID provided.")}, status=status.HTTP_400_BAD_REQUEST)
user = getattr(settings, 'SUBSCRIPTIONS_CLIENT_ID', None)
pw = getattr(settings, 'SUBSCRIPTIONS_CLIENT_SECRET', None)
if subscription_id and user and pw:
pool_id = data.get('pool_id', None)
if not pool_id:
return Response({"error": _("No subscription pool ID provided.")}, status=status.HTTP_400_BAD_REQUEST)
user = getattr(settings, 'SUBSCRIPTIONS_USERNAME', None)
pw = getattr(settings, 'SUBSCRIPTIONS_PASSWORD', None)
if pool_id and user and pw:
data = request.data.copy()
try:
with set_environ(**settings.AWX_TASK_ENV):
@@ -242,7 +239,7 @@ class ApiV2AttachView(APIView):
logger.exception(smart_str(u"Invalid subscription submitted."), extra=dict(actor=request.user.username))
return Response({"error": msg}, status=status.HTTP_400_BAD_REQUEST)
for sub in validated:
if sub['subscription_id'] == subscription_id:
if sub['pool_id'] == pool_id:
sub['valid_key'] = True
settings.LICENSE = sub
return Response(sub)

View File

@@ -10,7 +10,7 @@ from django.core.validators import URLValidator, _lazy_re_compile
from django.utils.translation import gettext_lazy as _
# Django REST Framework
from rest_framework.fields import BooleanField, CharField, ChoiceField, DictField, DateTimeField, EmailField, IntegerField, ListField, FloatField # noqa
from rest_framework.fields import BooleanField, CharField, ChoiceField, DictField, DateTimeField, EmailField, IntegerField, ListField # noqa
from rest_framework.serializers import PrimaryKeyRelatedField # noqa
# AWX
@@ -207,8 +207,7 @@ class URLField(CharField):
if self.allow_plain_hostname:
try:
url_parts = urlparse.urlsplit(value)
looks_like_ipv6 = bool(url_parts.netloc and url_parts.netloc.startswith('[') and url_parts.netloc.endswith(']'))
if not looks_like_ipv6 and url_parts.hostname and '.' not in url_parts.hostname:
if url_parts.hostname and '.' not in url_parts.hostname:
netloc = '{}.local'.format(url_parts.hostname)
if url_parts.port:
netloc = '{}:{}'.format(netloc, url_parts.port)

View File

@@ -27,5 +27,5 @@ def _migrate_setting(apps, old_key, new_key, encrypted=False):
def prefill_rh_credentials(apps, schema_editor):
_migrate_setting(apps, 'REDHAT_USERNAME', 'SUBSCRIPTIONS_CLIENT_ID', encrypted=False)
_migrate_setting(apps, 'REDHAT_PASSWORD', 'SUBSCRIPTIONS_CLIENT_SECRET', encrypted=True)
_migrate_setting(apps, 'REDHAT_USERNAME', 'SUBSCRIPTIONS_USERNAME', encrypted=False)
_migrate_setting(apps, 'REDHAT_PASSWORD', 'SUBSCRIPTIONS_PASSWORD', encrypted=True)

View File

@@ -38,7 +38,6 @@ class SettingsRegistry(object):
if setting in self._registry:
raise ImproperlyConfigured('Setting "{}" is already registered.'.format(setting))
category = kwargs.setdefault('category', None)
kwargs.setdefault('required', False) # No setting is ordinarily required
category_slug = kwargs.setdefault('category_slug', slugify(category or '') or None)
if category_slug in {'all', 'changed', 'user-defaults'}:
raise ImproperlyConfigured('"{}" is a reserved category slug.'.format(category_slug))

View File

@@ -128,41 +128,3 @@ class TestURLField:
else:
with pytest.raises(ValidationError):
field.run_validators(url)
@pytest.mark.parametrize(
"url, expect_error",
[
("https://[1:2:3]", True),
("http://[1:2:3]", True),
("https://[2001:db8:3333:4444:5555:6666:7777:8888", True),
("https://2001:db8:3333:4444:5555:6666:7777:8888", True),
("https://[2001:db8:3333:4444:5555:6666:7777:8888]", False),
("https://[::1]", False),
("https://[::]", False),
("https://[2001:db8::1]", False),
("https://[2001:db8:0:0:0:0:1:1]", False),
("https://[fe80::2%eth0]", True), # ipv6 scope identifier
("https://[fe80:0:0:0:200:f8ff:fe21:67cf]", False),
("https://[::ffff:192.168.1.10]", False),
("https://[0:0:0:0:0:ffff:c000:0201]", False),
("https://[2001:0db8:000a:0001:0000:0000:0000:0000]", False),
("https://[2001:db8:a:1::]", False),
("https://[ff02::1]", False),
("https://[ff02:0:0:0:0:0:0:1]", False),
("https://[fc00::1]", False),
("https://[fd12:3456:789a:1::1]", False),
("https://[2001:db8::abcd:ef12:3456:7890]", False),
("https://[2001:db8:0000:abcd:0000:ef12:0000:3456]", False),
("https://[::ffff:10.0.0.1]", False),
("https://[2001:db8:cafe::]", False),
("https://[2001:db8:cafe:0:0:0:0:0]", False),
("https://[fe80::210:f3ff:fedf:4567%3]", True), # ipv6 scope identifier, numerical interface
],
)
def test_ipv6_urls(self, url, expect_error):
field = URLField()
if expect_error:
with pytest.raises(ValidationError, match="Enter a valid URL"):
field.run_validators(url)
else:
field.run_validators(url)

View File

@@ -3,13 +3,13 @@ import logging
# AWX
from awx.main.analytics.subsystem_metrics import DispatcherMetrics, CallbackReceiverMetrics
from awx.main.dispatch.publish import task as task_awx
from awx.main.dispatch.publish import task
from awx.main.dispatch import get_task_queuename
logger = logging.getLogger('awx.main.scheduler')
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def send_subsystem_metrics():
DispatcherMetrics().send_metrics()
CallbackReceiverMetrics().send_metrics()

View File

@@ -142,7 +142,7 @@ def config(since, **kwargs):
return {
'platform': {
'system': platform.system(),
'dist': (distro.name(), distro.version(), distro.codename()),
'dist': distro.linux_distribution(),
'release': platform.release(),
'type': install_type,
},

View File

@@ -22,7 +22,7 @@ from ansible_base.lib.utils.db import advisory_lock
from awx.main.models import Job
from awx.main.access import access_registry
from awx.main.utils import get_awx_http_client_headers, set_environ, datetime_hook
from awx.main.utils.analytics_proxy import OIDCClient
from awx.main.utils.analytics_proxy import OIDCClient, DEFAULT_OIDC_TOKEN_ENDPOINT
__all__ = ['register', 'gather', 'ship']
@@ -186,7 +186,7 @@ def gather(dest=None, module=None, subset=None, since=None, until=None, collecti
if not (
settings.AUTOMATION_ANALYTICS_URL
and ((settings.REDHAT_USERNAME and settings.REDHAT_PASSWORD) or (settings.SUBSCRIPTIONS_CLIENT_ID and settings.SUBSCRIPTIONS_CLIENT_SECRET))
and ((settings.REDHAT_USERNAME and settings.REDHAT_PASSWORD) or (settings.SUBSCRIPTIONS_USERNAME and settings.SUBSCRIPTIONS_PASSWORD))
):
logger.log(log_level, "Not gathering analytics, configuration is invalid. Use --dry-run to gather locally without sending.")
return None
@@ -324,10 +324,10 @@ def gather(dest=None, module=None, subset=None, since=None, until=None, collecti
settings.AUTOMATION_ANALYTICS_LAST_ENTRIES = json.dumps(last_entries, cls=DjangoJSONEncoder)
if collection_type != 'dry-run':
for fpath in tarfiles:
if os.path.exists(fpath):
os.remove(fpath)
if succeeded:
for fpath in tarfiles:
if os.path.exists(fpath):
os.remove(fpath)
with disable_activity_stream():
if not settings.AUTOMATION_ANALYTICS_LAST_GATHER or until > settings.AUTOMATION_ANALYTICS_LAST_GATHER:
# `AUTOMATION_ANALYTICS_LAST_GATHER` is set whether collection succeeds or fails;
@@ -368,20 +368,8 @@ def ship(path):
logger.error('AUTOMATION_ANALYTICS_URL is not set')
return False
rh_id = getattr(settings, 'REDHAT_USERNAME', None)
rh_secret = getattr(settings, 'REDHAT_PASSWORD', None)
if not (rh_id and rh_secret):
rh_id = getattr(settings, 'SUBSCRIPTIONS_CLIENT_ID', None)
rh_secret = getattr(settings, 'SUBSCRIPTIONS_CLIENT_SECRET', None)
if not rh_id:
logger.error('Neither REDHAT_USERNAME nor SUBSCRIPTIONS_CLIENT_ID are set')
return False
if not rh_secret:
logger.error('Neither REDHAT_PASSWORD nor SUBSCRIPTIONS_CLIENT_SECRET are set')
return False
rh_user = getattr(settings, 'REDHAT_USERNAME', None)
rh_password = getattr(settings, 'REDHAT_PASSWORD', None)
with open(path, 'rb') as f:
files = {'file': (os.path.basename(path), f, settings.INSIGHTS_AGENT_MIME)}
@@ -389,13 +377,25 @@ def ship(path):
s.headers = get_awx_http_client_headers()
s.headers.pop('Content-Type')
with set_environ(**settings.AWX_TASK_ENV):
try:
client = OIDCClient(rh_id, rh_secret)
response = client.make_request("POST", url, headers=s.headers, files=files, verify=settings.INSIGHTS_CERT_PATH, timeout=(31, 31))
except requests.RequestException:
logger.error("Automation Analytics API request failed, trying base auth method")
response = s.post(url, files=files, verify=settings.INSIGHTS_CERT_PATH, auth=(rh_id, rh_secret), headers=s.headers, timeout=(31, 31))
if rh_user and rh_password:
try:
client = OIDCClient(rh_user, rh_password, DEFAULT_OIDC_TOKEN_ENDPOINT, ['api.console'])
response = client.make_request("POST", url, headers=s.headers, files=files, verify=settings.INSIGHTS_CERT_PATH, timeout=(31, 31))
except requests.RequestException:
logger.error("Automation Analytics API request failed, trying base auth method")
response = s.post(url, files=files, verify=settings.INSIGHTS_CERT_PATH, auth=(rh_user, rh_password), headers=s.headers, timeout=(31, 31))
elif not rh_user or not rh_password:
logger.info('REDHAT_USERNAME and REDHAT_PASSWORD are not set, using SUBSCRIPTIONS_USERNAME and SUBSCRIPTIONS_PASSWORD')
rh_user = getattr(settings, 'SUBSCRIPTIONS_USERNAME', None)
rh_password = getattr(settings, 'SUBSCRIPTIONS_PASSWORD', None)
if rh_user and rh_password:
response = s.post(url, files=files, verify=settings.INSIGHTS_CERT_PATH, auth=(rh_user, rh_password), headers=s.headers, timeout=(31, 31))
elif not rh_user:
logger.error('REDHAT_USERNAME and SUBSCRIPTIONS_USERNAME are not set')
return False
elif not rh_password:
logger.error('REDHAT_PASSWORD and SUBSCRIPTIONS_USERNAME are not set')
return False
# Accept 2XX status_codes
if response.status_code >= 300:
logger.error('Upload failed with status {}, {}'.format(response.status_code, response.text))

View File

@@ -128,7 +128,6 @@ def metrics():
registry=REGISTRY,
)
LICENSE_EXPIRY = Gauge('awx_license_expiry', 'Time before license expires', registry=REGISTRY)
LICENSE_INSTANCE_TOTAL = Gauge('awx_license_instance_total', 'Total number of managed hosts provided by your license', registry=REGISTRY)
LICENSE_INSTANCE_FREE = Gauge('awx_license_instance_free', 'Number of remaining managed hosts provided by your license', registry=REGISTRY)
@@ -149,7 +148,6 @@ def metrics():
}
)
LICENSE_EXPIRY.set(str(license_info.get('time_remaining', 0)))
LICENSE_INSTANCE_TOTAL.set(str(license_info.get('instance_count', 0)))
LICENSE_INSTANCE_FREE.set(str(license_info.get('free_instances', 0)))

View File

@@ -1,9 +1,6 @@
import os
from dispatcherd.config import setup as dispatcher_setup
from django.apps import AppConfig
from django.db import connection
from django.utils.translation import gettext_lazy as _
from awx.main.utils.common import bypass_in_test, load_all_entry_points_for
from awx.main.utils.migration import is_database_synchronized
@@ -79,28 +76,9 @@ class MainConfig(AppConfig):
cls = entry_point.load()
InventorySourceOptions.injectors[entry_point_name] = cls
def configure_dispatcherd(self):
"""This implements the default configuration for dispatcherd
If running the tasking service like awx-manage run_dispatcher,
some additional config will be applied on top of this.
This configuration provides the minimum such that code can submit
tasks to pg_notify to run those tasks.
"""
from awx.main.dispatch.config import get_dispatcherd_config
if connection.vendor != 'postgresql':
config_dict = get_dispatcherd_config(mock_publish=True)
else:
config_dict = get_dispatcherd_config()
dispatcher_setup(config_dict)
def ready(self):
super().ready()
self.configure_dispatcherd()
"""
Credential loading triggers database operations. There are cases we want to call
awx-manage collectstatic without a database. All management commands invoke the ready() code

View File

@@ -12,7 +12,6 @@ from rest_framework import serializers
from awx.conf import fields, register, register_validate
from awx.main.models import ExecutionEnvironment
from awx.main.constants import SUBSCRIPTION_USAGE_MODEL_UNIQUE_HOSTS
from awx.main.tasks.policy import OPA_AUTH_TYPES
logger = logging.getLogger('awx.main.conf')
@@ -91,6 +90,7 @@ register(
),
category=_('System'),
category_slug='system',
required=False,
)
register(
@@ -124,8 +124,8 @@ register(
allow_blank=True,
encrypted=False,
read_only=False,
label=_('Red Hat Client ID for Analytics'),
help_text=_('Client ID used to send data to Automation Analytics'),
label=_('Red Hat customer username'),
help_text=_('This username is used to send data to Automation Analytics'),
category=_('System'),
category_slug='system',
)
@@ -137,34 +137,34 @@ register(
allow_blank=True,
encrypted=True,
read_only=False,
label=_('Red Hat Client Secret for Analytics'),
help_text=_('Client secret used to send data to Automation Analytics'),
label=_('Red Hat customer password'),
help_text=_('This password is used to send data to Automation Analytics'),
category=_('System'),
category_slug='system',
)
register(
'SUBSCRIPTIONS_CLIENT_ID',
'SUBSCRIPTIONS_USERNAME',
field_class=fields.CharField,
default='',
allow_blank=True,
encrypted=False,
read_only=False,
label=_('Red Hat Client ID for Subscriptions'),
help_text=_('Client ID used to retrieve subscription and content information'), # noqa
label=_('Red Hat or Satellite username'),
help_text=_('This username is used to retrieve subscription and content information'), # noqa
category=_('System'),
category_slug='system',
)
register(
'SUBSCRIPTIONS_CLIENT_SECRET',
'SUBSCRIPTIONS_PASSWORD',
field_class=fields.CharField,
default='',
allow_blank=True,
encrypted=True,
read_only=False,
label=_('Red Hat Client Secret for Subscriptions'),
help_text=_('Client secret used to retrieve subscription and content information'), # noqa
label=_('Red Hat or Satellite password'),
help_text=_('This password is used to retrieve subscription and content information'), # noqa
category=_('System'),
category_slug='system',
)
@@ -237,6 +237,7 @@ register(
help_text=_('List of modules allowed to be used by ad-hoc jobs.'),
category=_('Jobs'),
category_slug='jobs',
required=False,
)
register(
@@ -247,6 +248,7 @@ register(
('never', _('Never')),
('template', _('Only On Job Template Definitions')),
],
required=True,
label=_('When can extra variables contain Jinja templates?'),
help_text=_(
'Ansible allows variable substitution via the Jinja2 templating '
@@ -271,6 +273,7 @@ register(
register(
'AWX_ISOLATION_SHOW_PATHS',
field_class=fields.StringListIsolatedPathField,
required=False,
label=_('Paths to expose to isolated jobs'),
help_text=_(
'List of paths that would otherwise be hidden to expose to isolated jobs. Enter one path per line. '
@@ -436,6 +439,7 @@ register(
register(
'AWX_ANSIBLE_CALLBACK_PLUGINS',
field_class=fields.StringListField,
required=False,
label=_('Ansible Callback Plugins'),
help_text=_('List of paths to search for extra callback plugins to be used when running jobs. Enter one path per line.'),
category=_('Jobs'),
@@ -549,6 +553,7 @@ register(
help_text=_('Port on Logging Aggregator to send logs to (if required and not provided in Logging Aggregator).'),
category=_('Logging'),
category_slug='logging',
required=False,
)
register(
'LOG_AGGREGATOR_TYPE',
@@ -570,6 +575,7 @@ register(
help_text=_('Username for external log aggregator (if required; HTTP/s only).'),
category=_('Logging'),
category_slug='logging',
required=False,
)
register(
'LOG_AGGREGATOR_PASSWORD',
@@ -581,6 +587,7 @@ register(
help_text=_('Password or authentication token for external log aggregator (if required; HTTP/s only).'),
category=_('Logging'),
category_slug='logging',
required=False,
)
register(
'LOG_AGGREGATOR_LOGGERS',
@@ -767,6 +774,7 @@ register(
allow_null=True,
category=_('System'),
category_slug='system',
required=False,
hidden=True,
)
register(
@@ -972,124 +980,3 @@ def csrf_trusted_origins_validate(serializer, attrs):
register_validate('system', csrf_trusted_origins_validate)
# Policy-as-code (OPA) connection settings, all registered under the
# 'policyascode' slug. The category label is wrapped in _() so that it is
# marked for translation, matching the other category labels in this module
# (a bare parenthesised string is just a str and is never translated).
register(
    'OPA_HOST',
    field_class=fields.CharField,
    label=_('OPA server hostname'),
    default='',
    help_text=_('The hostname used to connect to the OPA server. If empty, policy enforcement will be disabled.'),
    category=_('PolicyAsCode'),
    category_slug='policyascode',
    allow_blank=True,
)

register(
    'OPA_PORT',
    field_class=fields.IntegerField,
    label=_('OPA server port'),
    default=8181,
    help_text=_('The port used to connect to the OPA server. Defaults to 8181.'),
    category=_('PolicyAsCode'),
    category_slug='policyascode',
)

register(
    'OPA_SSL',
    field_class=fields.BooleanField,
    label=_('Use SSL for OPA connection'),
    default=False,
    help_text=_('Enable or disable the use of SSL to connect to the OPA server. Defaults to false.'),
    category=_('PolicyAsCode'),
    category_slug='policyascode',
)

register(
    'OPA_AUTH_TYPE',
    field_class=fields.ChoiceField,
    label=_('OPA authentication type'),
    choices=[OPA_AUTH_TYPES.NONE, OPA_AUTH_TYPES.TOKEN, OPA_AUTH_TYPES.CERTIFICATE],
    default=OPA_AUTH_TYPES.NONE,
    help_text=_('The authentication type that will be used to connect to the OPA server: "None", "Token", or "Certificate".'),
    category=_('PolicyAsCode'),
    category_slug='policyascode',
)

# Stored encrypted because it is a credential.
register(
    'OPA_AUTH_TOKEN',
    field_class=fields.CharField,
    label=_('OPA authentication token'),
    default='',
    help_text=_(
        'The token for authentication to the OPA server. Required when OPA_AUTH_TYPE is "Token". If an authorization header is defined in OPA_AUTH_CUSTOM_HEADERS, it will be overridden by OPA_AUTH_TOKEN.'
    ),
    category=_('PolicyAsCode'),
    category_slug='policyascode',
    allow_blank=True,
    encrypted=True,
)

register(
    'OPA_AUTH_CLIENT_CERT',
    field_class=fields.CharField,
    label=_('OPA client certificate content'),
    default='',
    help_text=_('The content of the client certificate file for mTLS authentication to the OPA server. Required when OPA_AUTH_TYPE is "Certificate".'),
    category=_('PolicyAsCode'),
    category_slug='policyascode',
    allow_blank=True,
)

# Stored encrypted because it is private key material.
register(
    'OPA_AUTH_CLIENT_KEY',
    field_class=fields.CharField,
    label=_('OPA client key content'),
    default='',
    help_text=_('The content of the client key for mTLS authentication to the OPA server. Required when OPA_AUTH_TYPE is "Certificate".'),
    category=_('PolicyAsCode'),
    category_slug='policyascode',
    allow_blank=True,
    encrypted=True,
)

register(
    'OPA_AUTH_CA_CERT',
    field_class=fields.CharField,
    label=_('OPA CA certificate content'),
    default='',
    help_text=_('The content of the CA certificate for mTLS authentication to the OPA server. Required when OPA_AUTH_TYPE is "Certificate".'),
    category=_('PolicyAsCode'),
    category_slug='policyascode',
    allow_blank=True,
)

register(
    'OPA_AUTH_CUSTOM_HEADERS',
    field_class=fields.DictField,
    label=_('OPA custom authentication headers'),
    default={},
    help_text=_('Optional custom headers included in requests to the OPA server. Defaults to empty dictionary ({}).'),
    category=_('PolicyAsCode'),
    category_slug='policyascode',
)

register(
    'OPA_REQUEST_TIMEOUT',
    field_class=fields.FloatField,
    label=_('OPA request timeout'),
    default=1.5,
    help_text=_('The number of seconds after which the connection to the OPA server will time out. Defaults to 1.5 seconds.'),
    category=_('PolicyAsCode'),
    category_slug='policyascode',
)

register(
    'OPA_REQUEST_RETRIES',
    field_class=fields.IntegerField,
    label=_('OPA request retry count'),
    default=2,
    help_text=_('The number of retry attempts for connecting to the OPA server. Default is 2.'),
    category=_('PolicyAsCode'),
    category_slug='policyascode',
)

View File

@@ -77,8 +77,6 @@ LOGGER_BLOCKLIST = (
'awx.main.utils.log',
# loggers that may be called getting logging settings
'awx.conf',
# dispatcherd should only use 1 database connection
'dispatcherd',
)
# Reported version for node seen in receptor mesh but for which capacity check

View File

@@ -1,53 +0,0 @@
from django.conf import settings
from ansible_base.lib.utils.db import get_pg_notify_params
from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.pool import get_auto_max_workers
def get_dispatcherd_config(for_service: bool = False, mock_publish: bool = False) -> dict:
    """Assemble the dispatcherd configuration dictionary from Django settings.

    Parameters:
        for_service: if True, also add the producers and the pg_notify channel
            list needed to run the dispatcher service. This requires database
            access, so delay evaluation until after app setup.
        mock_publish: if True, register a "noop" broker and make it the default
            publish broker; otherwise a "pg_notify" broker is registered and
            used as the default.

    Returns:
        The configuration dictionary.
    """
    service_section = {
        "pool_kwargs": {
            "min_workers": settings.JOB_EVENT_WORKERS,
            "max_workers": get_auto_max_workers(),
        },
        "main_kwargs": {"node_id": settings.CLUSTER_HOST_ID},
        "process_manager_cls": "ForkServerManager",
        "process_manager_kwargs": {"preload_modules": ['awx.main.dispatch.hazmat']},
    }
    broker_section = {
        "socket": {"socket_path": settings.DISPATCHERD_DEBUGGING_SOCKFILE},
    }
    publish_section = {"default_control_broker": "socket"}

    if not mock_publish:
        broker_section["pg_notify"] = {
            "config": get_pg_notify_params(),
            "sync_connection_factory": "ansible_base.lib.utils.db.psycopg_connection_from_django",
            "default_publish_channel": settings.CLUSTER_HOST_ID,  # used for debugging commands
        }
        publish_section["default_broker"] = "pg_notify"
    else:
        broker_section["noop"] = {}
        publish_section["default_broker"] = "noop"

    config = {
        "version": 2,
        "service": service_section,
        "brokers": broker_section,
        "publish": publish_section,
        "worker": {"worker_cls": "awx.main.dispatch.worker.dispatcherd.AWXTaskWorker"},
    }

    if for_service:
        config["producers"] = {
            "ScheduledProducer": {"task_schedule": settings.DISPATCHER_SCHEDULE},
            "OnStartProducer": {"task_list": {"awx.main.tasks.system.dispatch_startup": {}}},
            "ControlProducer": {},
        }
        # NOTE(review): this indexes the pg_notify broker, which is only present
        # when mock_publish is False; combining both flags raises KeyError.
        broker_section["pg_notify"]["channels"] = ['tower_broadcast_all', 'tower_settings_change', get_task_queuename()]

    return config

View File

@@ -1,36 +0,0 @@
# Preload script: imports modules and sets up Django at import time.
# NOTE(review): presumably this is the module named in the dispatcherd
# "preload_modules" setting, so forked workers share these imports — confirm.
# Statement order matters: prepare_env() and django.setup() must run before
# django.conf.settings is read below.
import django

# dispatcherd publisher logic is likely to be used, but needs manual preload
from dispatcherd.brokers import pg_notify  # noqa

# Cache may not be initialized until we are in the worker, so preload here
from channels_redis import core  # noqa

from awx import prepare_env

from dispatcherd.utils import resolve_callable

prepare_env()

django.setup()  # noqa

from django.conf import settings


# Preload all periodic tasks so their imports will be in shared memory
for name, options in settings.CELERYBEAT_SCHEDULE.items():
    resolve_callable(options['task'])


# Preload in-line import from tasks
from awx.main.scheduler.kubernetes import PodManager  # noqa


from django.core.cache import cache as django_cache
from django.db import connection


# Close the database connection and cache handle opened as a side effect of the
# preloading above.
# NOTE(review): presumably so they are not inherited by forked workers — confirm.
connection.close()
django_cache.close()

View File

@@ -7,7 +7,6 @@ import time
import traceback
from datetime import datetime
from uuid import uuid4
import json
import collections
from multiprocessing import Process
@@ -26,10 +25,7 @@ from ansible_base.lib.logging.runtime import log_excess_runtime
from awx.main.models import UnifiedJob
from awx.main.dispatch import reaper
from awx.main.utils.common import get_mem_effective_capacity, get_corrected_memory, get_corrected_cpu, get_cpu_effective_capacity
# ansible-runner
from ansible_runner.utils.capacity import get_mem_in_bytes, get_cpu_count
from awx.main.utils.common import convert_mem_str_to_bytes, get_mem_effective_capacity
if 'run_callback_receiver' in sys.argv:
logger = logging.getLogger('awx.main.commands.run_callback_receiver')
@@ -37,9 +33,6 @@ else:
logger = logging.getLogger('awx.main.dispatch')
RETIRED_SENTINEL_TASK = "[retired]"
class NoOpResultQueue(object):
def put(self, item):
pass
@@ -84,17 +77,11 @@ class PoolWorker(object):
self.queue = MPQueue(queue_size)
self.process = Process(target=target, args=(self.queue, self.finished) + args)
self.process.daemon = True
self.creation_time = time.monotonic()
self.retiring = False
def start(self):
self.process.start()
def put(self, body):
if self.retiring:
uuid = body.get('uuid', 'N/A') if isinstance(body, dict) else 'N/A'
logger.info(f"Worker pid:{self.pid} is retiring. Refusing new task {uuid}.")
raise QueueFull("Worker is retiring and not accepting new tasks") # AutoscalePool.write handles QueueFull
uuid = '?'
if isinstance(body, dict):
if not body.get('uuid'):
@@ -113,11 +100,6 @@ class PoolWorker(object):
"""
self.queue.put('QUIT')
@property
def age(self):
"""Returns the current age of the worker in seconds."""
return time.monotonic() - self.creation_time
@property
def pid(self):
return self.process.pid
@@ -164,8 +146,6 @@ class PoolWorker(object):
# the purpose of self.managed_tasks is to just track internal
# state of which events are *currently* being processed.
logger.warning('Event UUID {} appears to be have been duplicated.'.format(uuid))
if self.retiring:
self.managed_tasks[RETIRED_SENTINEL_TASK] = {'task': RETIRED_SENTINEL_TASK}
@property
def current_task(self):
@@ -281,8 +261,6 @@ class WorkerPool(object):
'{% for w in workers %}'
'. worker[pid:{{ w.pid }}]{% if not w.alive %} GONE exit={{ w.exitcode }}{% endif %}'
' sent={{ w.messages_sent }}'
' age={{ "%.0f"|format(w.age) }}s'
' retiring={{ w.retiring }}'
'{% if w.messages_finished %} finished={{ w.messages_finished }}{% endif %}'
' qsize={{ w.managed_tasks|length }}'
' rss={{ w.mb }}MB'
@@ -329,41 +307,6 @@ class WorkerPool(object):
logger.exception('could not kill {}'.format(worker.pid))
def get_auto_max_workers():
    """Compute the default max_workers value for the dispatcher pool.

    Uses almost the same logic as Instance.local_health_check.
    The important thing is to be MORE than Instance.capacity
    so that the task-manager does not over-schedule this node.

    Ideally we would just use the capacity from the database plus reserve workers,
    but this poses some bootstrap problems where OCP task containers
    register themselves after startup.
    """
    # Memory constraint: ansible-runner reports the memory in bytes, which may be
    # replaced by a user override before being converted to a capacity.
    memory_bytes = get_corrected_memory(get_mem_in_bytes())
    memory_capacity = get_mem_effective_capacity(memory_bytes, is_control_node=True)

    # CPU constraint follows the same process.
    cpu_capacity = get_cpu_effective_capacity(get_corrected_cpu(get_cpu_count()), is_control_node=True)

    # Unlike the health check, take the larger of the two capacities, then add a
    # magic number of extra workers so a few remain to run the heartbeat.
    return max(memory_capacity, cpu_capacity) + 7
class AutoscalePool(WorkerPool):
"""
An extended pool implementation that automatically scales workers up and
@@ -374,13 +317,22 @@ class AutoscalePool(WorkerPool):
def __init__(self, *args, **kwargs):
self.max_workers = kwargs.pop('max_workers', None)
self.max_worker_lifetime_seconds = kwargs.pop(
'max_worker_lifetime_seconds', getattr(settings, 'WORKER_MAX_LIFETIME_SECONDS', 14400)
) # Default to 4 hours
super(AutoscalePool, self).__init__(*args, **kwargs)
if self.max_workers is None:
self.max_workers = get_auto_max_workers()
settings_absmem = getattr(settings, 'SYSTEM_TASK_ABS_MEM', None)
if settings_absmem is not None:
# There are 1073741824 bytes in a gigabyte. Convert bytes to gigabytes by dividing by 2**30
total_memory_gb = convert_mem_str_to_bytes(settings_absmem) // 2**30
else:
total_memory_gb = (psutil.virtual_memory().total >> 30) + 1 # noqa: round up
# Get same number as max forks based on memory, this function takes memory as bytes
self.max_workers = get_mem_effective_capacity(total_memory_gb * 2**30)
# add magic prime number of extra workers to ensure
# we have a few extra workers to run the heartbeat
self.max_workers += 7
# max workers can't be less than min_workers
self.max_workers = max(self.min_workers, self.max_workers)
@@ -394,9 +346,6 @@ class AutoscalePool(WorkerPool):
self.scale_up_ct = 0
self.worker_count_max = 0
# last time we wrote current tasks, to avoid too much log spam
self.last_task_list_log = time.monotonic()
def produce_subsystem_metrics(self, metrics_object):
metrics_object.set('dispatcher_pool_scale_up_events', self.scale_up_ct)
metrics_object.set('dispatcher_pool_active_task_count', sum(len(w.managed_tasks) for w in self.workers))
@@ -436,7 +385,6 @@ class AutoscalePool(WorkerPool):
"""
orphaned = []
for w in self.workers[::]:
is_retirement_age = self.max_worker_lifetime_seconds is not None and w.age > self.max_worker_lifetime_seconds
if not w.alive:
# the worker process has exited
# 1. take the task it was running and enqueue the error
@@ -445,10 +393,6 @@ class AutoscalePool(WorkerPool):
# send them to another worker
logger.error('worker pid:{} is gone (exit={})'.format(w.pid, w.exitcode))
if w.current_task:
if w.current_task == {'task': RETIRED_SENTINEL_TASK}:
logger.debug('scaling down worker pid:{} due to worker age: {}'.format(w.pid, w.age))
self.workers.remove(w)
continue
if w.current_task != 'QUIT':
try:
for j in UnifiedJob.objects.filter(celery_task_id=w.current_task['uuid']):
@@ -459,7 +403,6 @@ class AutoscalePool(WorkerPool):
logger.warning(f'Worker was told to quit but has not, pid={w.pid}')
orphaned.extend(w.orphaned_tasks)
self.workers.remove(w)
elif w.idle and len(self.workers) > self.min_workers:
# the process has an empty queue (it's idle) and we have
# more processes in the pool than we need (> min)
@@ -468,22 +411,6 @@ class AutoscalePool(WorkerPool):
logger.debug('scaling down worker pid:{}'.format(w.pid))
w.quit()
self.workers.remove(w)
elif w.idle and is_retirement_age:
logger.debug('scaling down worker pid:{} due to worker age: {}'.format(w.pid, w.age))
w.quit()
self.workers.remove(w)
elif is_retirement_age and not w.retiring and not w.idle:
logger.info(
f"Worker pid:{w.pid} (age: {w.age:.0f}s) exceeded max lifetime ({self.max_worker_lifetime_seconds:.0f}s). "
"Signaling for graceful retirement."
)
# Send QUIT signal; worker will finish current task then exit.
w.quit()
# mark as retiring to reject any future tasks that might be assigned in meantime
w.retiring = True
if w.alive:
# if we discover a task manager invocation that's been running
# too long, reap it (because otherwise it'll just hold the postgres
@@ -536,14 +463,6 @@ class AutoscalePool(WorkerPool):
self.worker_count_max = new_worker_ct
return ret
@staticmethod
def fast_task_serialization(current_task):
    """Return a short text slug for a task message.

    The slug joins the task name, the sorted positional args and the sorted
    kwargs keys with ' - '. If anything goes wrong while building it (for
    example the message is not a dict, or the args cannot be sorted), fall back
    to str(current_task).
    """
    try:
        pieces = (
            current_task.get('task'),
            sorted(current_task.get('args', [])),
            sorted(current_task.get('kwargs', {})),
        )
        return ' - '.join(str(piece) for piece in pieces)
    except Exception:
        # just make sure this does not make things worse
        return str(current_task)
def write(self, preferred_queue, body):
if 'guid' in body:
set_guid(body['guid'])
@@ -565,15 +484,6 @@ class AutoscalePool(WorkerPool):
if isinstance(body, dict):
task_name = body.get('task')
logger.warning(f'Workers maxed, queuing {task_name}, load: {sum(len(w.managed_tasks) for w in self.workers)} / {len(self.workers)}')
# Once every 10 seconds write out task list for debugging
if time.monotonic() - self.last_task_list_log >= 10.0:
task_counts = {}
for worker in self.workers:
task_slug = self.fast_task_serialization(worker.current_task)
task_counts.setdefault(task_slug, 0)
task_counts[task_slug] += 1
logger.info(f'Running tasks by count:\n{json.dumps(task_counts, indent=2)}')
self.last_task_list_log = time.monotonic()
return super(AutoscalePool, self).write(preferred_queue, body)
except Exception:
for conn in connections.all():

View File

@@ -4,9 +4,6 @@ import json
import time
from uuid import uuid4
from dispatcherd.publish import submit_task
from dispatcherd.utils import resolve_callable
from django_guid import get_guid
from django.conf import settings
@@ -96,19 +93,6 @@ class task:
@classmethod
def apply_async(cls, args=None, kwargs=None, queue=None, uuid=None, **kw):
try:
from flags.state import flag_enabled
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
# At this point we have the import string, and submit_task wants the method, so back to that
actual_task = resolve_callable(cls.name)
return submit_task(actual_task, args=args, kwargs=kwargs, queue=queue, uuid=uuid, **kw)
except Exception:
logger.exception(f"[DISPATCHER] Failed to check for alternative dispatcherd implementation for {cls.name}")
# Continue with original implementation if anything fails
pass
# Original implementation follows
queue = queue or getattr(cls.queue, 'im_func', cls.queue)
if not queue:
msg = f'{cls.name}: Queue value required and may not be None'

View File

@@ -238,7 +238,7 @@ class AWXConsumerPG(AWXConsumerBase):
def run(self, *args, **kwargs):
super(AWXConsumerPG, self).run(*args, **kwargs)
logger.info(f"Running {self.name}, workers min={self.pool.min_workers} max={self.pool.max_workers}, listening to queues {self.queues}")
logger.info(f"Running worker {self.name} listening to queues {self.queues}")
init = False
while True:

View File

@@ -1,14 +0,0 @@
from dispatcherd.worker.task import TaskWorker
from django.db import connection
class AWXTaskWorker(TaskWorker):
    """TaskWorker subclass that manages the Django database connection in its hooks."""

    def on_start(self) -> None:
        """Get worker connected so that first task it gets will be worked quickly"""
        connection.ensure_connection()

    def pre_task(self, message) -> None:
        """This should remedy bad connections that can not fix themselves"""
        # `message` is not used here; only the connection state is checked.
        connection.close_if_unusable_or_obsolete()

View File

@@ -38,12 +38,5 @@ class PostRunError(Exception):
super(PostRunError, self).__init__(msg)
class PolicyEvaluationError(Exception):
    """Exception that carries a status and traceback text alongside its message.

    Attributes are ``status`` (defaults to 'failed') and ``tb`` (defaults to '').
    """

    def __init__(self, msg, status='failed', tb=''):
        super().__init__(msg)
        self.tb = tb
        self.status = status
class ReceptorNodeNotFound(RuntimeError):
    # RuntimeError subclass that adds no behavior of its own.
    # NOTE(review): presumably raised when a receptor node lookup fails — confirm against the raising code.
    pass

View File

@@ -33,7 +33,6 @@ from awx.main.utils.safe_yaml import sanitize_jinja
from awx.main.models.rbac import batch_role_ancestor_rebuilding
from awx.main.utils import ignore_inventory_computed_fields, get_licenser
from awx.main.utils.execution_environments import get_default_execution_environment
from awx.main.utils.inventory_vars import update_group_variables
from awx.main.signals import disable_activity_stream
from awx.main.constants import STANDARD_INVENTORY_UPDATE_ENV
@@ -458,19 +457,19 @@ class Command(BaseCommand):
"""
Update inventory variables from "all" group.
"""
# TODO: We disable variable overwrite here in case user-defined inventory variables get
# mangled. But we still need to figure out a better way of processing multiple inventory
# update variables mixing with each other.
# issue for this: https://github.com/ansible/awx/issues/11623
if self.inventory.kind == 'constructed' and self.inventory_source.overwrite_vars:
# NOTE: we had to add a exception case to not merge variables
# to make constructed inventory coherent
db_variables = self.all_group.variables
else:
db_variables = update_group_variables(
group_id=None, # `None` denotes the 'all' group (which doesn't have a pk).
newvars=self.all_group.variables,
dbvars=self.inventory.variables_dict,
invsrc_id=self.inventory_source.id,
inventory_id=self.inventory.id,
overwrite_vars=self.overwrite_vars,
)
db_variables = self.inventory.variables_dict
db_variables.update(self.all_group.variables)
if db_variables != self.inventory.variables_dict:
self.inventory.variables = json.dumps(db_variables)
self.inventory.save(update_fields=['variables'])

View File

@@ -2,21 +2,13 @@
# All Rights Reserved.
import logging
import yaml
import os
import redis
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError
from flags.state import flag_enabled
from dispatcherd.factories import get_control_from_settings
from dispatcherd import run_service
from dispatcherd.config import setup as dispatcher_setup
from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.config import get_dispatcherd_config
from awx.main.dispatch.control import Control
from awx.main.dispatch.pool import AutoscalePool
from awx.main.dispatch.worker import AWXConsumerPG, TaskWorker
@@ -48,44 +40,18 @@ class Command(BaseCommand):
),
)
def verify_dispatcherd_socket(self):
if not os.path.exists(settings.DISPATCHERD_DEBUGGING_SOCKFILE):
raise CommandError('Dispatcher is not running locally')
def handle(self, *arg, **options):
if options.get('status'):
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
ctl = get_control_from_settings()
running_data = ctl.control_with_reply('status')
if len(running_data) != 1:
raise CommandError('Did not receive expected number of replies')
print(yaml.dump(running_data[0], default_flow_style=False))
return
else:
print(Control('dispatcher').status())
return
print(Control('dispatcher').status())
return
if options.get('schedule'):
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
print('NOT YET IMPLEMENTED')
return
else:
print(Control('dispatcher').schedule())
print(Control('dispatcher').schedule())
return
if options.get('running'):
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
ctl = get_control_from_settings()
running_data = ctl.control_with_reply('running')
print(yaml.dump(running_data, default_flow_style=False))
return
else:
print(Control('dispatcher').running())
return
print(Control('dispatcher').running())
return
if options.get('reload'):
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
print('NOT YET IMPLEMENTED')
return
else:
return Control('dispatcher').control({'control': 'reload'})
return Control('dispatcher').control({'control': 'reload'})
if options.get('cancel'):
cancel_str = options.get('cancel')
try:
@@ -94,36 +60,21 @@ class Command(BaseCommand):
cancel_data = [cancel_str]
if not isinstance(cancel_data, list):
cancel_data = [cancel_str]
print(Control('dispatcher').cancel(cancel_data))
return
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
ctl = get_control_from_settings()
results = []
for task_id in cancel_data:
# For each task UUID, send an individual cancel command
result = ctl.control_with_reply('cancel', data={'uuid': task_id})
results.append(result)
print(yaml.dump(results, default_flow_style=False))
return
else:
print(Control('dispatcher').cancel(cancel_data))
return
consumer = None
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
dispatcher_setup(get_dispatcherd_config(for_service=True))
run_service()
else:
consumer = None
try:
DispatcherMetricsServer().start()
except redis.exceptions.ConnectionError as exc:
raise CommandError(f'Dispatcher could not connect to redis, error: {exc}')
try:
DispatcherMetricsServer().start()
except redis.exceptions.ConnectionError as exc:
raise CommandError(f'Dispatcher could not connect to redis, error: {exc}')
try:
queues = ['tower_broadcast_all', 'tower_settings_change', get_task_queuename()]
consumer = AWXConsumerPG('dispatcher', TaskWorker(), queues, AutoscalePool(min_workers=4), schedule=settings.CELERYBEAT_SCHEDULE)
consumer.run()
except KeyboardInterrupt:
logger.debug('Terminating Task Dispatcher')
if consumer:
consumer.stop()
try:
queues = ['tower_broadcast_all', 'tower_settings_change', get_task_queuename()]
consumer = AWXConsumerPG('dispatcher', TaskWorker(), queues, AutoscalePool(min_workers=4), schedule=settings.CELERYBEAT_SCHEDULE)
consumer.run()
except KeyboardInterrupt:
logger.debug('Terminating Task Dispatcher')
if consumer:
consumer.stop()

View File

@@ -1,61 +0,0 @@
# Generated by Django 4.2.18 on 2025-02-27 20:35
from django.db import migrations, models
class Migration(migrations.Migration):
    """Alter the ``source`` field on InventorySource and InventoryUpdate.

    Both fields become a 32-character CharField with ``default=None`` and the
    same explicit list of choices.
    """

    dependencies = [('main', '0197_add_opa_query_path')]

    operations = [
        migrations.AlterField(
            model_name='inventorysource',
            name='source',
            field=models.CharField(
                choices=[
                    ('file', 'File, Directory or Script'),
                    ('constructed', 'Template additional groups and hostvars at runtime'),
                    ('scm', 'Sourced from a Project'),
                    ('ec2', 'Amazon EC2'),
                    ('gce', 'Google Compute Engine'),
                    ('azure_rm', 'Microsoft Azure Resource Manager'),
                    ('vmware', 'VMware vCenter'),
                    ('vmware_esxi', 'VMware ESXi'),
                    ('satellite6', 'Red Hat Satellite 6'),
                    ('openstack', 'OpenStack'),
                    ('rhv', 'Red Hat Virtualization'),
                    ('controller', 'Red Hat Ansible Automation Platform'),
                    ('insights', 'Red Hat Insights'),
                    ('terraform', 'Terraform State'),
                    ('openshift_virtualization', 'OpenShift Virtualization'),
                ],
                default=None,
                max_length=32,
            ),
        ),
        # Same field definition as above, applied to InventoryUpdate.
        migrations.AlterField(
            model_name='inventoryupdate',
            name='source',
            field=models.CharField(
                choices=[
                    ('file', 'File, Directory or Script'),
                    ('constructed', 'Template additional groups and hostvars at runtime'),
                    ('scm', 'Sourced from a Project'),
                    ('ec2', 'Amazon EC2'),
                    ('gce', 'Google Compute Engine'),
                    ('azure_rm', 'Microsoft Azure Resource Manager'),
                    ('vmware', 'VMware vCenter'),
                    ('vmware_esxi', 'VMware ESXi'),
                    ('satellite6', 'Red Hat Satellite 6'),
                    ('openstack', 'OpenStack'),
                    ('rhv', 'Red Hat Virtualization'),
                    ('controller', 'Red Hat Ansible Automation Platform'),
                    ('insights', 'Red Hat Insights'),
                    ('terraform', 'Terraform State'),
                    ('openshift_virtualization', 'OpenShift Virtualization'),
                ],
                default=None,
                max_length=32,
            ),
        ),
    ]

View File

@@ -0,0 +1,15 @@
# Generated by Django 4.2.10 on 2024-09-16 10:22
from django.db import migrations
class Migration(migrations.Migration):
    """Delete the ``Profile`` model from the ``main`` app."""

    dependencies = [
        ('main', '0197_add_opa_query_path'),
    ]

    operations = [
        migrations.DeleteModel(
            name='Profile',
        ),
    ]

View File

@@ -1,32 +0,0 @@
# Generated by Django 4.2.20 on 2025-04-24 09:08
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
    """Create the InventoryGroupVariablesWithHistory model.

    Each row holds a JSON ``variables`` value plus nullable foreign keys to a
    group and an inventory (both cascade on delete). A unique constraint covers
    the (inventory, group) pair.
    """

    dependencies = [
        ('main', '0198_alter_inventorysource_source_and_more'),
    ]

    operations = [
        migrations.CreateModel(
            name='InventoryGroupVariablesWithHistory',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('variables', models.JSONField()),
                ('group', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='inventory_group_variables', to='main.group')),
                (
                    'inventory',
                    models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='inventory_group_variables', to='main.inventory'),
                ),
            ],
        ),
        migrations.AddConstraint(
            model_name='inventorygroupvariableswithhistory',
            constraint=models.UniqueConstraint(
                fields=('inventory', 'group'), name='unique_inventory_group', violation_error_message='Inventory/Group combination must be unique.'
            ),
        ),
    ]

View File

@@ -0,0 +1,26 @@
# Generated by Django 4.2.10 on 2024-09-16 15:21
from django.db import migrations
class Migration(migrations.Migration):
    """Remove database content left behind by the ``sso`` application using raw SQL.

    None of the RunSQL operations passes ``reverse_sql``, so this migration
    cannot be reversed. Rows are deleted from the referencing tables first
    (group permissions, then permissions, then content types).
    """

    dependencies = [
        ('main', '0198_delete_profile'),
    ]

    operations = [
        # delete all sso application migrations
        migrations.RunSQL("DELETE FROM django_migrations WHERE app = 'sso';"),
        # delete all sso application content group permissions
        migrations.RunSQL(
            "DELETE FROM auth_group_permissions "
            "WHERE permission_id IN "
            "(SELECT id FROM auth_permission WHERE content_type_id in (SELECT id FROM django_content_type WHERE app_label = 'sso'));"
        ),
        # delete all sso application content permissions
        migrations.RunSQL("DELETE FROM auth_permission " "WHERE content_type_id IN (SELECT id FROM django_content_type WHERE app_label = 'sso');"),
        # delete sso application content type
        migrations.RunSQL("DELETE FROM django_content_type WHERE app_label = 'sso';"),
        # drop sso application created table
        migrations.RunSQL("DROP TABLE IF EXISTS sso_userenterpriseauth;"),
    ]

View File

@@ -0,0 +1,23 @@
# Generated by Django 4.2.10 on 2024-10-22 15:58
from django.db import migrations, models
class Migration(migrations.Migration):
    """Alter ``source`` on InventorySource and InventoryUpdate to a CharField with no ``choices``.

    Both fields are defined as a 32-character CharField with ``default=None``.
    """

    dependencies = [
        ('main', '0199_remove_sso_app_content'),
    ]

    operations = [
        migrations.AlterField(
            model_name='inventorysource',
            name='source',
            field=models.CharField(default=None, max_length=32),
        ),
        migrations.AlterField(
            model_name='inventoryupdate',
            name='source',
            field=models.CharField(default=None, max_length=32),
        ),
    ]

View File

@@ -1,50 +0,0 @@
# Generated by Django 4.2.20 on 2025-04-22 15:54
import logging
from django.db import migrations, models
from awx.main.migrations._db_constraints import _rename_duplicates
logger = logging.getLogger(__name__)
def rename_jts(apps, schema_editor):
    """Run ``_rename_duplicates`` on the historical JobTemplate model."""
    _rename_duplicates(apps.get_model('main', 'JobTemplate'))
def rename_projects(apps, schema_editor):
    """Run ``_rename_duplicates`` on the historical Project model."""
    _rename_duplicates(apps.get_model('main', 'Project'))
def change_inventory_source_org_unique(apps, schema_editor):
    """Set ``org_unique`` to False on every InventorySource row and log the number of rows updated."""
    inventory_source_model = apps.get_model('main', 'InventorySource')
    updated_rows = inventory_source_model.objects.update(org_unique=False)
    logger.info(f'Set database constraint rule for {updated_rows} inventory source objects')
class Migration(migrations.Migration):
    """Add ``org_unique`` to UnifiedJobTemplate and a conditional unique constraint on names.

    The operations run in order: the two rename steps, the new field, the
    inventory-source update, and finally the constraint. The constraint applies
    only to rows where ``org_unique`` is True. All RunPython steps use a no-op
    reverse.
    """

    dependencies = [
        ('main', '0199_inventorygroupvariableswithhistory_and_more'),
    ]

    operations = [
        # Rename steps run first, before the constraint is added below.
        migrations.RunPython(rename_jts, migrations.RunPython.noop),
        migrations.RunPython(rename_projects, migrations.RunPython.noop),
        migrations.AddField(
            model_name='unifiedjobtemplate',
            name='org_unique',
            field=models.BooleanField(blank=True, default=True, editable=False, help_text='Used internally to selectively enforce database constraint on name'),
        ),
        # Sets org_unique to False on inventory sources, so the conditional constraint below does not apply to them.
        migrations.RunPython(change_inventory_source_org_unique, migrations.RunPython.noop),
        migrations.AddConstraint(
            model_name='unifiedjobtemplate',
            constraint=models.UniqueConstraint(
                condition=models.Q(('org_unique', True)), fields=('polymorphic_ctype', 'name', 'organization'), name='ujt_hard_name_constraint'
            ),
        ),
    ]

View File

@@ -0,0 +1,39 @@
# Generated by Django 4.2.10 on 2024-10-24 14:06
from django.db import migrations
class Migration(migrations.Migration):
    """Remove the OAuth2Application and OAuth2AccessToken models.

    The unique-together setting and the fields on OAuth2Application, plus the
    two ActivityStream fields named after these models, are removed before the
    models themselves are deleted.
    """

    dependencies = [
        ('main', '0200_alter_inventorysource_source_and_more'),
    ]

    operations = [
        migrations.AlterUniqueTogether(
            name='oauth2application',
            unique_together=None,
        ),
        migrations.RemoveField(
            model_name='oauth2application',
            name='organization',
        ),
        migrations.RemoveField(
            model_name='oauth2application',
            name='user',
        ),
        migrations.RemoveField(
            model_name='activitystream',
            name='o_auth2_access_token',
        ),
        migrations.RemoveField(
            model_name='activitystream',
            name='o_auth2_application',
        ),
        migrations.DeleteModel(
            name='OAuth2AccessToken',
        ),
        migrations.DeleteModel(
            name='OAuth2Application',
        ),
    ]

View File

@@ -1,9 +0,0 @@
from django.db import migrations
class Migration(migrations.Migration):
    """Migration with no operations; it only depends on 0200_template_name_constraint.

    NOTE(review): presumably a placeholder that keeps the migration numbering
    or dependency chain intact — confirm.
    """

    dependencies = [
        ('main', '0200_template_name_constraint'),
    ]
    operations = []

View File

@@ -0,0 +1,44 @@
# Generated by Django 4.2.16 on 2024-12-18 16:05
from django.db import migrations, models
from awx.main.migrations._create_system_jobs import delete_clear_tokens_sjt
# Migration that runs the `delete_clear_tokens_sjt` data step and then redefines the
# `job_type` choices on SystemJob and SystemJobTemplate to three cleanup job types.
class Migration(migrations.Migration):
dependencies = [
('main', '0201_alter_oauth2application_unique_together_and_more'),
]
operations = [
# Data step imported from _create_system_jobs; its reverse direction is a no-op.
# NOTE(review): presumably deletes the token-cleanup system job template - confirm in _create_system_jobs.
migrations.RunPython(delete_clear_tokens_sjt, migrations.RunPython.noop),
# The two AlterField operations below carry identical field definitions and must stay in sync.
migrations.AlterField(
model_name='systemjob',
name='job_type',
field=models.CharField(
blank=True,
choices=[
('cleanup_jobs', 'Remove jobs older than a certain number of days'),
('cleanup_activitystream', 'Remove activity stream entries older than a certain number of days'),
('cleanup_sessions', 'Removes expired browser sessions from the database'),
],
default='',
max_length=32,
),
),
migrations.AlterField(
model_name='systemjobtemplate',
name='job_type',
field=models.CharField(
blank=True,
choices=[
('cleanup_jobs', 'Remove jobs older than a certain number of days'),
('cleanup_activitystream', 'Remove activity stream entries older than a certain number of days'),
('cleanup_sessions', 'Removes expired browser sessions from the database'),
],
default='',
max_length=32,
),
),
]

View File

@@ -1,100 +0,0 @@
# Generated by Django 4.2.10 on 2024-09-16 10:22
from django.db import migrations, models
from awx.main.migrations._create_system_jobs import delete_clear_tokens_sjt
# Combined migration: deletes the Profile model, purges database rows and the table
# belonging to the 'sso' app via raw SQL, relaxes the inventory source `source` field,
# removes the OAuth2 models, and redefines the system job `job_type` choices.
class Migration(migrations.Migration):
dependencies = [
('main', '0201_create_managed_creds'),
]
operations = [
migrations.DeleteModel(
name='Profile',
),
# Remove SSO app content
# NOTE: none of the RunSQL operations below supply reverse SQL.
# delete all sso application migrations
migrations.RunSQL("DELETE FROM django_migrations WHERE app = 'sso';"),
# delete all sso application content group permissions
migrations.RunSQL(
"DELETE FROM auth_group_permissions "
"WHERE permission_id IN "
"(SELECT id FROM auth_permission WHERE content_type_id in (SELECT id FROM django_content_type WHERE app_label = 'sso'));"
),
# delete all sso application content permissions
migrations.RunSQL("DELETE FROM auth_permission " "WHERE content_type_id IN (SELECT id FROM django_content_type WHERE app_label = 'sso');"),
# delete sso application content type
migrations.RunSQL("DELETE FROM django_content_type WHERE app_label = 'sso';"),
# drop sso application created table
migrations.RunSQL("DROP TABLE IF EXISTS sso_userenterpriseauth;"),
# Alter inventory source source field
# (the new definition is a plain CharField with no `choices` argument)
migrations.AlterField(
model_name='inventorysource',
name='source',
field=models.CharField(default=None, max_length=32),
),
migrations.AlterField(
model_name='inventoryupdate',
name='source',
field=models.CharField(default=None, max_length=32),
),
# Alter OAuth2Application unique together
# (cleared first, before the fields are removed below)
migrations.AlterUniqueTogether(
name='oauth2application',
unique_together=None,
),
migrations.RemoveField(
model_name='oauth2application',
name='organization',
),
migrations.RemoveField(
model_name='oauth2application',
name='user',
),
# Remove the activity stream fields named after the OAuth2 models before deleting those models.
migrations.RemoveField(
model_name='activitystream',
name='o_auth2_access_token',
),
migrations.RemoveField(
model_name='activitystream',
name='o_auth2_application',
),
migrations.DeleteModel(
name='OAuth2AccessToken',
),
migrations.DeleteModel(
name='OAuth2Application',
),
# Delete system token cleanup jobs, because tokens were deleted
migrations.RunPython(delete_clear_tokens_sjt, migrations.RunPython.noop),
# The two AlterField operations below carry identical `job_type` definitions and must stay in sync.
migrations.AlterField(
model_name='systemjob',
name='job_type',
field=models.CharField(
blank=True,
choices=[
('cleanup_jobs', 'Remove jobs older than a certain number of days'),
('cleanup_activitystream', 'Remove activity stream entries older than a certain number of days'),
('cleanup_sessions', 'Removes expired browser sessions from the database'),
],
default='',
max_length=32,
),
),
migrations.AlterField(
model_name='systemjobtemplate',
name='job_type',
field=models.CharField(
blank=True,
choices=[
('cleanup_jobs', 'Remove jobs older than a certain number of days'),
('cleanup_activitystream', 'Remove activity stream entries older than a certain number of days'),
('cleanup_sessions', 'Removes expired browser sessions from the database'),
],
default='',
max_length=32,
),
),
]

View File

@@ -1,25 +0,0 @@
import logging
from django.db.models import Count
logger = logging.getLogger(__name__)
def _rename_duplicates(cls):
    """Rename objects of ``cls`` that share a name within the same organization.

    For every organization, each group of objects with an identical ``name`` is
    ordered by ``created``; the oldest keeps its name and every later one gets a
    ``_dup<N>`` suffix. The base name is truncated where necessary so the result
    still fits in the ``name`` field's ``max_length``.

    Unlike a plain ``_dup<idx>`` scheme, a suffix is only used if no object in the
    same organization already carries the resulting name, so the rename itself can
    never introduce a new duplicate.

    :param cls: model class exposing ``name``, ``organization_id`` and ``created``.
    """
    max_len = cls._meta.get_field('name').max_length
    # order_by() clears any default ordering so DISTINCT / GROUP BY operate on the selected column only
    for organization_id in cls.objects.order_by().values_list('organization_id', flat=True).distinct():
        duplicate_data = cls.objects.values('name').filter(organization_id=organization_id).annotate(count=Count('name')).order_by().filter(count__gt=1)
        for data in duplicate_data:
            name = data['name']
            dup_number = 0  # last suffix number tried for this name
            for idx, ujt in enumerate(cls.objects.filter(name=name, organization_id=organization_id).order_by('created')):
                if idx == 0:
                    continue  # the oldest object keeps the original name
                while True:
                    dup_number += 1
                    suffix = f'_dup{dup_number}'
                    # slicing is a no-op for short names, so no separate length check is needed
                    candidate = ujt.name[: max_len - len(suffix)] + suffix
                    # skip suffixes whose resulting name is already taken in this organization
                    if not cls.objects.filter(name=candidate, organization_id=organization_id).exists():
                        break
                ujt.name = candidate
                logger.info(f'Renaming duplicate {cls._meta.model_name} to `{ujt.name}` because of duplicate name entry')
                ujt.save(update_fields=['name'])

View File

@@ -33,7 +33,6 @@ from awx.main.models.inventory import ( # noqa
InventorySource,
InventoryUpdate,
SmartInventoryMembership,
InventoryGroupVariablesWithHistory,
)
from awx.main.models.jobs import ( # noqa
Job,

View File

@@ -24,7 +24,6 @@ from awx.main.managers import DeferJobCreatedManager
from awx.main.constants import MINIMAL_EVENTS
from awx.main.models.base import CreatedModifiedModel
from awx.main.utils import ignore_inventory_computed_fields, camelcase_to_underscore
from awx.main.utils.db import bulk_update_sorted_by_id
analytics_logger = logging.getLogger('awx.analytics.job_events')
@@ -603,7 +602,7 @@ class JobEvent(BasePlaybookEvent):
h.last_job_host_summary_id = host_mapping[h.id]
updated_hosts.add(h)
bulk_update_sorted_by_id(Host, updated_hosts, ['last_job_id', 'last_job_host_summary_id'])
Host.objects.bulk_update(list(updated_hosts), ['last_job_id', 'last_job_host_summary_id'], batch_size=100)
# Create/update Host Metrics
self._update_host_metrics(updated_hosts_list)

View File

@@ -1120,10 +1120,8 @@ class InventorySource(UnifiedJobTemplate, InventorySourceOptions, CustomVirtualE
def save(self, *args, **kwargs):
# if this is a new object, inherit organization from its inventory
if not self.pk:
self.org_unique = False # needed to exclude from unique (name, organization) constraint
if self.inventory and self.inventory.organization_id and not self.organization_id:
self.organization_id = self.inventory.organization_id
if not self.pk and self.inventory and self.inventory.organization_id and not self.organization_id:
self.organization_id = self.inventory.organization_id
# If update_fields has been specified, add our field names to it,
# if it hasn't been specified, then we're just doing a normal save.
@@ -1404,38 +1402,3 @@ class CustomInventoryScript(CommonModelNameNotUnique):
def get_absolute_url(self, request=None):
return reverse('api:inventory_script_detail', kwargs={'pk': self.pk}, request=request)
# Storage-only model: one row per (inventory, group) pair holding that group's
# variables together with their update history as JSON.
class InventoryGroupVariablesWithHistory(models.Model):
"""
Represents the inventory variables of one inventory group.
The purpose of this model is to persist the update history of the group
variables. The update history is maintained in another class
(`InventoryGroupVariables`), this class here is just a container for the
database storage.
"""
class Meta:
constraints = [
# Do not allow the same inventory/group combination more than once.
# NOTE(review): `group` is nullable (NULL denotes the 'all'-group). On databases that treat
# NULLs as distinct in unique indexes (e.g. PostgreSQL by default), this constraint would not
# block two rows with the same inventory and group=NULL - confirm this is guarded elsewhere.
models.UniqueConstraint(
fields=["inventory", "group"],
name="unique_inventory_group",
violation_error_message=_("Inventory/Group combination must be unique."),
),
]
# Owning inventory; rows are removed together with the inventory (CASCADE).
inventory = models.ForeignKey(
'Inventory',
related_name='inventory_group_variables',
null=True,
on_delete=models.CASCADE,
)
# Group the variables belong to; rows are removed together with the group (CASCADE).
group = models.ForeignKey( # `None` denotes the 'all'-group.
'Group',
related_name='inventory_group_variables',
null=True,
on_delete=models.CASCADE,
)
variables = models.JSONField() # The group variables, including their history.

View File

@@ -358,6 +358,26 @@ class JobTemplate(
update_fields.append('organization_id')
return super(JobTemplate, self).save(*args, **kwargs)
def validate_unique(self, exclude=None):
    """Uniqueness check specific to job templates.

    The organization of a job template is inferred from its project only after
    full_clean has finished, so the organization field is not populated yet when
    validation runs. The lookup therefore uses the project's organization (or
    None when there is no project) instead of the field itself.

    Raises ValidationError when another JobTemplate with the same name exists
    in that organization. ``exclude`` is accepted for signature compatibility
    and is not consulted.
    """
    errors = []
    for unique_fields in JobTemplate.SOFT_UNIQUE_TOGETHER:
        organization_id = self.project.organization_id if self.project else None
        conflicts = JobTemplate.objects.filter(name=self.name, organization=organization_id)
        if self.pk:
            # an existing object must not conflict with itself
            conflicts = conflicts.exclude(pk=self.pk)
        if not conflicts.exists():
            continue
        # polymorphic_ctype is an internal detail, keep it out of the message
        field_names = ', '.join(set(unique_fields) - {'polymorphic_ctype'})
        errors.append('%s with this (%s) combination already exists.' % (JobTemplate.__name__, field_names))
    if errors:
        raise ValidationError(errors)
def create_unified_job(self, **kwargs):
prevent_slicing = kwargs.pop('_prevent_slicing', False)
slice_ct = self.get_effective_slice_ct(kwargs)
@@ -384,26 +404,6 @@ class JobTemplate(
WorkflowJobNode.objects.create(**create_kwargs)
return job
def validate_unique(self, exclude=None):
    """Uniqueness check specific to job templates.

    The organization of a job template is inferred from its project only after
    full_clean has finished, so the organization field is not populated yet when
    validation runs. The lookup therefore uses the project's organization (or
    None when there is no project) instead of the field itself.

    Raises ValidationError when another JobTemplate with the same name exists
    in that organization. ``exclude`` is accepted for signature compatibility
    and is not consulted.
    """
    errors = []
    for unique_fields in JobTemplate.SOFT_UNIQUE_TOGETHER:
        organization_id = self.project.organization_id if self.project else None
        conflicts = JobTemplate.objects.filter(name=self.name, organization=organization_id)
        if self.pk:
            # an existing object must not conflict with itself
            conflicts = conflicts.exclude(pk=self.pk)
        if not conflicts.exists():
            continue
        # polymorphic_ctype is an internal detail, keep it out of the message
        field_names = ', '.join(set(unique_fields) - {'polymorphic_ctype'})
        errors.append('%s with this (%s) combination already exists.' % (JobTemplate.__name__, field_names))
    if errors:
        raise ValidationError(errors)
def get_absolute_url(self, request=None):
return reverse('api:job_template_detail', kwargs={'pk': self.pk}, request=request)

View File

@@ -18,13 +18,11 @@ from collections import OrderedDict
# Django
from django.conf import settings
from django.db import models, connection, transaction
from django.db.models.constraints import UniqueConstraint
from django.core.exceptions import NON_FIELD_ERRORS
from django.utils.translation import gettext_lazy as _
from django.utils.timezone import now
from django.utils.encoding import smart_str
from django.contrib.contenttypes.models import ContentType
from flags.state import flag_enabled
# REST Framework
from rest_framework.exceptions import ParseError
@@ -113,10 +111,7 @@ class UnifiedJobTemplate(PolymorphicModel, CommonModelNameNotUnique, ExecutionEn
ordering = ('name',)
# unique_together here is intentionally commented out. Please make sure sub-classes of this model
# contain at least this uniqueness restriction: SOFT_UNIQUE_TOGETHER = [('polymorphic_ctype', 'name')]
# Unique name constraint - note that inventory source model is excluded from this constraint entirely
constraints = [
UniqueConstraint(fields=['polymorphic_ctype', 'name', 'organization'], condition=models.Q(org_unique=True), name='ujt_hard_name_constraint')
]
# unique_together = [('polymorphic_ctype', 'name', 'organization')]
old_pk = models.PositiveIntegerField(
null=True,
@@ -185,9 +180,6 @@ class UnifiedJobTemplate(PolymorphicModel, CommonModelNameNotUnique, ExecutionEn
)
labels = models.ManyToManyField("Label", blank=True, related_name='%(class)s_labels')
instance_groups = OrderedManyToManyField('InstanceGroup', blank=True, through='UnifiedJobTemplateInstanceGroupMembership')
org_unique = models.BooleanField(
blank=True, default=True, editable=False, help_text=_('Used internally to selectively enforce database constraint on name')
)
def get_absolute_url(self, request=None):
real_instance = self.get_real_instance()
@@ -1370,30 +1362,7 @@ class UnifiedJob(
traceback=self.result_traceback,
)
def get_start_kwargs(self):
    """Build the start keyword arguments from the stored, encrypted start_args.

    Returns None when start_args is empty, equal to '{}', or cannot be parsed as
    JSON (the parse failure is logged). Otherwise returns a dict with one entry
    per field reported by get_passwords_needed_to_start(), using '' for fields
    absent from start_args. If any entry is falsy, job_explanation is set to
    list the missing fields and saved; the dict is still returned.
    """
    required_fields = self.get_passwords_needed_to_start()

    raw_start_args = decrypt_field(self, 'start_args')
    if not raw_start_args or raw_start_args == '{}':
        return None

    try:
        start_args = json.loads(raw_start_args)
    except Exception:
        logger.exception(f'Unexpected malformed start_args on unified_job={self.id}')
        return None

    opts = {field: start_args.get(field, '') for field in required_fields}

    # record which required values were not supplied, but leave the decision to the caller
    missing = [key for key, value in opts.items() if not value]
    if missing:
        self.job_explanation = u'Missing needed fields: %s.' % ', '.join(missing)
        self.save(update_fields=['job_explanation'])

    return opts
def pre_start(self):
def pre_start(self, **kwargs):
if not self.can_start:
self.job_explanation = u'%s is not in a startable state: %s, expecting one of %s' % (self._meta.verbose_name, self.status, str(('new', 'waiting')))
self.save(update_fields=['job_explanation'])
@@ -1414,11 +1383,26 @@ class UnifiedJob(
self.save(update_fields=['job_explanation'])
return (False, None)
opts = self.get_start_kwargs()
needed = self.get_passwords_needed_to_start()
try:
start_args = json.loads(decrypt_field(self, 'start_args'))
except Exception:
start_args = None
if opts and (not all(opts.values())):
if start_args in (None, ''):
start_args = kwargs
opts = dict([(field, start_args.get(field, '')) for field in needed])
if not all(opts.values()):
missing_fields = ', '.join([k for k, v in opts.items() if not v])
self.job_explanation = u'Missing needed fields: %s.' % missing_fields
self.save(update_fields=['job_explanation'])
return (False, None)
if 'extra_vars' in kwargs:
self.handle_extra_data(kwargs['extra_vars'])
# remove any job_explanations that may have been set while job was in pending
if self.job_explanation != "":
self.job_explanation = ""
@@ -1479,44 +1463,21 @@ class UnifiedJob(
def cancel_dispatcher_process(self):
"""Returns True if dispatcher running this job acknowledged request and sent SIGTERM"""
if not self.celery_task_id:
return False
return
canceled = []
# Special case for task manager (used during workflow job cancellation)
if not connection.get_autocommit():
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
try:
from dispatcherd.factories import get_control_from_settings
ctl = get_control_from_settings()
ctl.control('cancel', data={'uuid': self.celery_task_id})
except Exception:
logger.exception("Error sending cancel command to new dispatcher")
else:
try:
ControlDispatcher('dispatcher', self.controller_node).cancel([self.celery_task_id], with_reply=False)
except Exception:
logger.exception("Error sending cancel command to legacy dispatcher")
# this condition is purpose-written for the task manager, when it cancels jobs in workflows
ControlDispatcher('dispatcher', self.controller_node).cancel([self.celery_task_id], with_reply=False)
return True # task manager itself needs to act under assumption that cancel was received
# Standard case with reply
try:
# Use control and reply mechanism to cancel and obtain confirmation
timeout = 5
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
from dispatcherd.factories import get_control_from_settings
ctl = get_control_from_settings()
results = ctl.control_with_reply('cancel', data={'uuid': self.celery_task_id}, expected_replies=1, timeout=timeout)
# Check if cancel was successful by checking if we got any results
return bool(results and len(results) > 0)
else:
# Original implementation
canceled = ControlDispatcher('dispatcher', self.controller_node).cancel([self.celery_task_id])
canceled = ControlDispatcher('dispatcher', self.controller_node).cancel([self.celery_task_id])
except socket.timeout:
logger.error(f'could not reach dispatcher on {self.controller_node} within {timeout}s')
except Exception:
logger.exception("error encountered when checking task status")
return bool(self.celery_task_id in canceled) # True or False, whether confirmation was obtained
def cancel(self, job_explanation=None, is_chain=False):

View File

@@ -19,9 +19,6 @@ from django.utils.timezone import now as tz_now
from django.conf import settings
from django.contrib.contenttypes.models import ContentType
# django-flags
from flags.state import flag_enabled
from ansible_base.lib.utils.models import get_type_for_model
# django-ansible-base
@@ -51,7 +48,6 @@ from awx.main.signals import disable_activity_stream
from awx.main.constants import ACTIVE_STATES
from awx.main.scheduler.dependency_graph import DependencyGraph
from awx.main.scheduler.task_manager_models import TaskManagerModels
from awx.main.tasks.jobs import dispatch_waiting_jobs
import awx.main.analytics.subsystem_metrics as s_metrics
from awx.main.utils import decrypt_field
@@ -435,7 +431,6 @@ class TaskManager(TaskBase):
# 5 minutes to start pending jobs. If this limit is reached, pending jobs
# will no longer be started and will be started on the next task manager cycle.
self.time_delta_job_explanation = timedelta(seconds=30)
self.control_nodes_to_notify: set[str] = set()
super().__init__(prefix="task_manager")
def after_lock_init(self):
@@ -524,19 +519,16 @@ class TaskManager(TaskBase):
task.save()
task.log_lifecycle("waiting")
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
self.control_nodes_to_notify.add(task.get_queue_name())
else:
# apply_async does a NOTIFY to the channel dispatcher is listening to
# postgres will treat this as part of the transaction, which is what we want
if task.status != 'failed' and type(task) is not WorkflowJob:
task_cls = task._get_task_class()
task_cls.apply_async(
[task.pk],
opts,
queue=task.get_queue_name(),
uuid=task.celery_task_id,
)
# apply_async does a NOTIFY to the channel dispatcher is listening to
# postgres will treat this as part of the transaction, which is what we want
if task.status != 'failed' and type(task) is not WorkflowJob:
task_cls = task._get_task_class()
task_cls.apply_async(
[task.pk],
opts,
queue=task.get_queue_name(),
uuid=task.celery_task_id,
)
# In exception cases, like a job failing pre-start checks, we send the websocket status message.
# For jobs going into waiting, we omit this because of performance issues, as it should go to running quickly
@@ -729,8 +721,3 @@ class TaskManager(TaskBase):
for workflow_approval in self.get_expired_workflow_approvals():
self.timeout_approval_node(workflow_approval)
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
for controller_node in self.control_nodes_to_notify:
logger.info(f'Notifying node {controller_node} of new waiting jobs.')
dispatch_waiting_jobs.apply_async(queue=controller_node)

View File

@@ -7,7 +7,7 @@ from django.conf import settings
# AWX
from awx import MODE
from awx.main.scheduler import TaskManager, DependencyManager, WorkflowManager
from awx.main.dispatch.publish import task as task_awx
from awx.main.dispatch.publish import task
from awx.main.dispatch import get_task_queuename
logger = logging.getLogger('awx.main.scheduler')
@@ -20,16 +20,16 @@ def run_manager(manager, prefix):
manager().schedule()
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def task_manager():
run_manager(TaskManager, "task")
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def dependency_manager():
run_manager(DependencyManager, "dependency")
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def workflow_manager():
run_manager(WorkflowManager, "workflow")

View File

@@ -1 +1 @@
from . import callback, facts, helpers, host_indirect, host_metrics, jobs, receptor, system # noqa
from . import host_metrics, jobs, receptor, system # noqa

View File

@@ -8,13 +8,13 @@ import logging
from django.conf import settings
from django.utils.encoding import smart_str
from django.utils.timezone import now
from django.db import OperationalError
# django-ansible-base
from ansible_base.lib.logging.runtime import log_excess_runtime
# AWX
from awx.main.utils.db import bulk_update_sorted_by_id
from awx.main.models import Host
from awx.main.models.inventory import Host
logger = logging.getLogger('awx.main.tasks.facts')
@@ -22,29 +22,27 @@ system_tracking_logger = logging.getLogger('awx.analytics.system_tracking')
@log_excess_runtime(logger, debug_cutoff=0.01, msg='Inventory {inventory_id} host facts prepared for {written_ct} hosts, took {delta:.3f} s', add_log_data=True)
def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_data=None):
log_data = log_data or {}
def start_fact_cache(hosts, destination, log_data, timeout=None, inventory_id=None):
log_data['inventory_id'] = inventory_id
log_data['written_ct'] = 0
hosts_cached = []
# Create the fact_cache directory inside artifacts_dir
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
os.makedirs(fact_cache_dir, mode=0o700, exist_ok=True)
hosts_cached = list()
try:
os.makedirs(destination, mode=0o700)
except FileExistsError:
pass
if timeout is None:
timeout = settings.ANSIBLE_FACT_CACHE_TIMEOUT
last_write_time = None
last_filepath_written = None
for host in hosts:
hosts_cached.append(host.name)
hosts_cached.append(host)
if not host.ansible_facts_modified or (timeout and host.ansible_facts_modified < now() - datetime.timedelta(seconds=timeout)):
continue # facts are expired - do not write them
filepath = os.path.join(fact_cache_dir, host.name)
if not os.path.realpath(filepath).startswith(fact_cache_dir):
logger.error(f'facts for host {smart_str(host.name)} could not be cached')
filepath = os.sep.join(map(str, [destination, host.name]))
if not os.path.realpath(filepath).startswith(destination):
system_tracking_logger.error('facts for host {} could not be cached'.format(smart_str(host.name)))
continue
try:
@@ -52,21 +50,37 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
os.chmod(f.name, 0o600)
json.dump(host.ansible_facts, f)
log_data['written_ct'] += 1
last_write_time = os.path.getmtime(filepath)
last_filepath_written = filepath
except IOError:
logger.error(f'facts for host {smart_str(host.name)} could not be cached')
system_tracking_logger.error('facts for host {} could not be cached'.format(smart_str(host.name)))
continue
# Write summary file directly to the artifacts_dir
if inventory_id is not None:
summary_file = os.path.join(artifacts_dir, 'host_cache_summary.json')
summary_data = {
'last_write_time': last_write_time,
'hosts_cached': hosts_cached,
'written_ct': log_data['written_ct'],
}
with open(summary_file, 'w', encoding='utf-8') as f:
json.dump(summary_data, f, indent=2)
if last_filepath_written:
return os.path.getmtime(last_filepath_written), hosts_cached
return None, hosts_cached
def raw_update_hosts(host_list):
    """Write the fact fields of the given hosts in one bulk update, without any retry."""
    fact_fields = ['ansible_facts', 'ansible_facts_modified']
    Host.objects.bulk_update(host_list, fact_fields)
def update_hosts(host_list, max_tries=5):
    """Persist host facts via raw_update_hosts, retrying on OperationalError.

    Does nothing for an empty host_list. The update is attempted up to
    max_tries times; an OperationalError on any attempt but the last is logged
    and retried, and on the last attempt it is re-raised.
    """
    if not host_list:
        return

    for i in range(max_tries):
        try:
            raw_update_hosts(host_list)
        except OperationalError as exc:
            # Deadlocks can happen if this runs at the same time as another large query
            # inventory updates and updating last_job_host_summary are candidates for conflict
            # but these would resolve easily on a retry
            is_last_attempt = i + 1 >= max_tries
            if is_last_attempt:
                raise
            logger.info(f'OperationalError (suspected deadlock) saving host facts retry {i}, message: {exc}')
        else:
            # update went through, no further attempts needed
            return
@log_excess_runtime(
@@ -75,54 +89,32 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
msg='Inventory {inventory_id} host facts: updated {updated_ct}, cleared {cleared_ct}, unchanged {unmodified_ct}, took {delta:.3f} s',
add_log_data=True,
)
def finish_fact_cache(artifacts_dir, job_id=None, inventory_id=None, log_data=None):
log_data = log_data or {}
def finish_fact_cache(hosts_cached, destination, facts_write_time, log_data, job_id=None, inventory_id=None):
log_data['inventory_id'] = inventory_id
log_data['updated_ct'] = 0
log_data['unmodified_ct'] = 0
log_data['cleared_ct'] = 0
# The summary file is directly inside the artifacts dir
summary_path = os.path.join(artifacts_dir, 'host_cache_summary.json')
if not os.path.exists(summary_path):
logger.error(f'Missing summary file at {summary_path}')
return
try:
with open(summary_path, 'r', encoding='utf-8') as f:
summary = json.load(f)
facts_write_time = os.path.getmtime(summary_path) # After successful read
except (json.JSONDecodeError, OSError) as e:
logger.error(f'Error reading summary file at {summary_path}: {e}')
return
host_names = summary.get('hosts_cached', [])
hosts_cached = Host.objects.filter(name__in=host_names).order_by('id').iterator()
# Path where individual fact files were written
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
hosts_to_update = []
for host in hosts_cached:
filepath = os.path.join(fact_cache_dir, host.name)
if not os.path.realpath(filepath).startswith(fact_cache_dir):
logger.error(f'Invalid path for facts file: {filepath}')
filepath = os.sep.join(map(str, [destination, host.name]))
if not os.path.realpath(filepath).startswith(destination):
system_tracking_logger.error('facts for host {} could not be cached'.format(smart_str(host.name)))
continue
if os.path.exists(filepath):
# If the file changed since we wrote the last facts file, pre-playbook run...
modified = os.path.getmtime(filepath)
if not facts_write_time or modified >= facts_write_time:
try:
with codecs.open(filepath, 'r', encoding='utf-8') as f:
if (not facts_write_time) or modified > facts_write_time:
with codecs.open(filepath, 'r', encoding='utf-8') as f:
try:
ansible_facts = json.load(f)
except ValueError:
continue
if ansible_facts != host.ansible_facts:
except ValueError:
continue
host.ansible_facts = ansible_facts
host.ansible_facts_modified = now()
hosts_to_update.append(host)
logger.info(
f'New fact for inventory {smart_str(host.inventory.name)} host {smart_str(host.name)}',
system_tracking_logger.info(
'New fact for inventory {} host {}'.format(smart_str(host.inventory.name), smart_str(host.name)),
extra=dict(
inventory_id=host.inventory.id,
host_name=host.name,
@@ -132,8 +124,6 @@ def finish_fact_cache(artifacts_dir, job_id=None, inventory_id=None, log_data=No
),
)
log_data['updated_ct'] += 1
else:
log_data['unmodified_ct'] += 1
else:
log_data['unmodified_ct'] += 1
else:
@@ -142,11 +132,9 @@ def finish_fact_cache(artifacts_dir, job_id=None, inventory_id=None, log_data=No
host.ansible_facts = {}
host.ansible_facts_modified = now()
hosts_to_update.append(host)
logger.info(f'Facts cleared for inventory {smart_str(host.inventory.name)} host {smart_str(host.name)}')
system_tracking_logger.info('Facts cleared for inventory {} host {}'.format(smart_str(host.inventory.name), smart_str(host.name)))
log_data['cleared_ct'] += 1
if len(hosts_to_update) >= 100:
bulk_update_sorted_by_id(Host, hosts_to_update, fields=['ansible_facts', 'ansible_facts_modified'])
if len(hosts_to_update) > 100:
update_hosts(hosts_to_update)
hosts_to_update = []
bulk_update_sorted_by_id(Host, hosts_to_update, fields=['ansible_facts', 'ansible_facts_modified'])
update_hosts(hosts_to_update)

View File

@@ -12,7 +12,7 @@ from django.db import transaction
# Django flags
from flags.state import flag_enabled
from awx.main.dispatch.publish import task as task_awx
from awx.main.dispatch.publish import task
from awx.main.dispatch import get_task_queuename
from awx.main.models.indirect_managed_node_audit import IndirectManagedNodeAudit
from awx.main.models.event_query import EventQuery
@@ -77,14 +77,7 @@ def build_indirect_host_data(job: Job, job_event_queries: dict[str, dict[str, st
if jq_str_for_event not in compiled_jq_expressions:
compiled_jq_expressions[resolved_action] = jq.compile(jq_str_for_event)
compiled_jq = compiled_jq_expressions[resolved_action]
try:
data_source = compiled_jq.input(event.event_data['res']).all()
except Exception as e:
logger.warning(f'error for module {resolved_action} and data {event.event_data["res"]}: {e}')
continue
for data in data_source:
for data in compiled_jq.input(event.event_data['res']).all():
# From this jq result (specific to a single Ansible module), get index information about this host record
if not data.get('canonical_facts'):
if not facts_missing_logged:
@@ -159,7 +152,7 @@ def cleanup_old_indirect_host_entries() -> None:
IndirectManagedNodeAudit.objects.filter(created__lt=limit).delete()
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def save_indirect_host_entries(job_id: int, wait_for_events: bool = True) -> None:
try:
job = Job.objects.get(id=job_id)
@@ -201,7 +194,7 @@ def save_indirect_host_entries(job_id: int, wait_for_events: bool = True) -> Non
logger.exception(f'Error processing indirect host data for job_id={job_id}')
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def cleanup_and_save_indirect_host_entries_fallback() -> None:
if not flag_enabled("FEATURE_INDIRECT_NODE_COUNTING_ENABLED"):
return

View File

@@ -7,18 +7,17 @@ from django.db.models import Count, F
from django.db.models.functions import TruncMonth
from django.utils.timezone import now
from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.publish import task as task_awx
from awx.main.dispatch.publish import task
from awx.main.models.inventory import HostMetric, HostMetricSummaryMonthly
from awx.main.tasks.helpers import is_run_threshold_reached
from awx.conf.license import get_license
from ansible_base.lib.utils.db import advisory_lock
from awx.main.utils.db import bulk_update_sorted_by_id
logger = logging.getLogger('awx.main.tasks.host_metrics')
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def cleanup_host_metrics():
if is_run_threshold_reached(getattr(settings, 'CLEANUP_HOST_METRICS_LAST_TS', None), getattr(settings, 'CLEANUP_HOST_METRICS_INTERVAL', 30) * 86400):
logger.info(f"Executing cleanup_host_metrics, last ran at {getattr(settings, 'CLEANUP_HOST_METRICS_LAST_TS', '---')}")
@@ -29,7 +28,7 @@ def cleanup_host_metrics():
logger.info("Finished cleanup_host_metrics")
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def host_metric_summary_monthly():
"""Run cleanup host metrics summary monthly task each week"""
if is_run_threshold_reached(getattr(settings, 'HOST_METRIC_SUMMARY_TASK_LAST_TS', None), getattr(settings, 'HOST_METRIC_SUMMARY_TASK_INTERVAL', 7) * 86400):
@@ -147,9 +146,8 @@ class HostMetricSummaryMonthlyTask:
month = month + relativedelta(months=1)
# Create/Update stats
HostMetricSummaryMonthly.objects.bulk_create(self.records_to_create)
bulk_update_sorted_by_id(HostMetricSummaryMonthly, self.records_to_update, ['license_consumed', 'hosts_added', 'hosts_deleted'])
HostMetricSummaryMonthly.objects.bulk_create(self.records_to_create, batch_size=1000)
HostMetricSummaryMonthly.objects.bulk_update(self.records_to_update, ['license_consumed', 'hosts_added', 'hosts_deleted'], batch_size=1000)
# Set timestamp of last run
settings.HOST_METRIC_SUMMARY_TASK_LAST_TS = now()

View File

@@ -17,11 +17,11 @@ import urllib.parse as urlparse
# Django
from django.conf import settings
from django.db import transaction
# Shared code for the AWX platform
from awx_plugins.interfaces._temporary_private_container_api import CONTAINER_ROOT, get_incontainer_path
# Runner
import ansible_runner
@@ -29,12 +29,9 @@ import ansible_runner
import git
from gitdb.exc import BadName as BadGitName
# Dispatcherd
from dispatcherd.publish import task
from dispatcherd.utils import serialize_task
# AWX
from awx.main.dispatch.publish import task as task_awx
from awx.main.dispatch.publish import task
from awx.main.dispatch import get_task_queuename
from awx.main.constants import (
PRIVILEGE_ESCALATION_METHODS,
@@ -42,13 +39,13 @@ from awx.main.constants import (
JOB_FOLDER_PREFIX,
MAX_ISOLATED_PATH_COLON_DELIMITER,
CONTAINER_VOLUMES_MOUNT_TYPES,
ACTIVE_STATES,
HOST_FACTS_FIELDS,
)
from awx.main.models import (
Instance,
Inventory,
InventorySource,
UnifiedJob,
Job,
AdHocCommand,
ProjectUpdate,
@@ -68,12 +65,11 @@ from awx.main.tasks.callback import (
RunnerCallbackForProjectUpdate,
RunnerCallbackForSystemJob,
)
from awx.main.tasks.policy import evaluate_policy
from awx.main.tasks.signals import with_signal_handling, signal_callback
from awx.main.tasks.receptor import AWXReceptorJob
from awx.main.tasks.facts import start_fact_cache, finish_fact_cache
from awx.main.tasks.system import update_smart_memberships_for_inventory, update_inventory_computed_fields, events_processed_hook
from awx.main.exceptions import AwxTaskError, PolicyEvaluationError, PostRunError, ReceptorNodeNotFound
from awx.main.exceptions import AwxTaskError, PostRunError, ReceptorNodeNotFound
from awx.main.utils.ansible import read_ansible_config
from awx.main.utils.safe_yaml import safe_dump, sanitize_jinja
from awx.main.utils.common import (
@@ -115,15 +111,6 @@ def with_path_cleanup(f):
return _wrapped
@task(on_duplicate='queue_one', bind=True, queue=get_task_queuename)
def dispatch_waiting_jobs(binder):
    """Send a 'run' control message for every unified job waiting on this controller node.

    The query selects jobs whose status is 'waiting' and whose controller_node
    equals settings.CLUSTER_HOST_ID, loading only the columns used below.

    :param binder: object passed in because the task is declared with bind=True;
        only its ``control()`` method is used here.
    """
    waiting_jobs = UnifiedJob.objects.filter(status='waiting', controller_node=settings.CLUSTER_HOST_ID).only(
        'id', 'status', 'polymorphic_ctype', 'celery_task_id'
    )
    for unified_job in waiting_jobs:
        # A falsy result from get_start_kwargs() is normalized to an empty dict.
        start_kwargs = unified_job.get_start_kwargs() or {}
        run_request = {
            'task': serialize_task(unified_job._get_task_class()),
            'args': [unified_job.id],
            'kwargs': start_kwargs,
            'uuid': unified_job.celery_task_id,
        }
        binder.control('run', data=run_request)
class BaseTask(object):
model = None
event_model = None
@@ -131,7 +118,6 @@ class BaseTask(object):
callback_class = RunnerCallback
def __init__(self):
self.instance = None
self.cleanup_paths = []
self.update_attempts = int(getattr(settings, 'DISPATCHER_DB_DOWNTOWN_TOLLERANCE', settings.DISPATCHER_DB_DOWNTIME_TOLERANCE) / 5)
self.runner_callback = self.callback_class(model=self.model)
@@ -319,8 +305,6 @@ class BaseTask(object):
# Add ANSIBLE_* settings to the subprocess environment.
for attr in dir(settings):
if attr == attr.upper() and attr.startswith('ANSIBLE_') and not attr.startswith('ANSIBLE_BASE_'):
if attr == 'ANSIBLE_STANDARD_SETTINGS_FILES':
continue # special case intended only for dynaconf use
env[attr] = str(getattr(settings, attr))
# Also set environment variables configured in AWX_TASK_ENV setting.
for key, value in settings.AWX_TASK_ENV.items():
@@ -459,21 +443,6 @@ class BaseTask(object):
"""
instance.log_lifecycle("finalize_run")
artifact_dir = os.path.join(private_data_dir, 'artifacts', str(self.instance.id))
collections_info = os.path.join(artifact_dir, 'collections.json')
ansible_version_file = os.path.join(artifact_dir, 'ansible_version.txt')
if os.path.exists(collections_info):
with open(collections_info) as ee_json_info:
ee_collections_info = json.loads(ee_json_info.read())
instance.installed_collections = ee_collections_info
instance.save(update_fields=['installed_collections'])
if os.path.exists(ansible_version_file):
with open(ansible_version_file) as ee_ansible_info:
ansible_version_info = ee_ansible_info.readline()
instance.ansible_version = ansible_version_info
instance.save(update_fields=['ansible_version'])
# Run task manager appropriately for speculative dependencies
if instance.unifiedjob_blocked_jobs.exists():
ScheduleTaskManager().schedule()
@@ -483,48 +452,27 @@ class BaseTask(object):
def should_use_fact_cache(self):
return False
def transition_status(self, pk: int) -> bool:
    """Atomically move a waiting job to its next status.

    Returns False only when the row is already 'running' (another process
    owns the job). Returns True in every other case, including statuses
    that are neither 'waiting' nor 'running', which are left untouched.
    """
    with transaction.atomic():
        # Explanation of parts for the fetch:
        # .values - avoid loading a full object, this is known to lead to deadlocks due to signals
        #           the signals load other related rows which another process may be locking, and happens in practice
        # of=('self',) - keeps FK tables out of the lock list, another way deadlocks can happen
        # .get - just load the single job
        locked_row = UnifiedJob.objects.select_for_update(of=('self',)).values('status', 'cancel_flag').get(pk=pk)
        current_status = locked_row['status']

        if current_status == 'running':
            logger.info(f'Job {pk} is being ran by another process, exiting')
            return False

        # Only a status of 'waiting' (read under the row lock) gives this process clearance to run
        if current_status == 'waiting':
            next_status = 'canceled' if locked_row['cancel_flag'] else 'running'
            # Explanation of the update:
            # .filter - again, do not load the full object
            # .update - a bulk update on just that one row, avoid loading unintended data
            UnifiedJob.objects.filter(pk=pk).update(status=next_status, start_args='')
    return True
@with_path_cleanup
@with_signal_handling
def run(self, pk, **kwargs):
"""
Run the job/task and capture its output.
"""
if not self.instance: # Used to skip fetch for local runs
if not self.transition_status(pk):
logger.info(f'Job {pk} is being ran by another process, exiting')
return
self.instance = self.model.objects.get(pk=pk)
if self.instance.status != 'canceled' and self.instance.cancel_flag:
self.instance = self.update_model(self.instance.pk, start_args='', status='canceled')
if self.instance.status not in ACTIVE_STATES:
# Prevent starting the job if it has been reaped or handled by another process.
raise RuntimeError(f'Not starting {self.instance.status} task pk={pk} because {self.instance.status} is not a valid active state')
# Load the instance
self.instance = self.update_model(pk)
if self.instance.status != 'running':
logger.error(f'Not starting {self.instance.status} task pk={pk} because its status "{self.instance.status}" is not expected')
return
if self.instance.execution_environment_id is None:
from awx.main.signals import disable_activity_stream
with disable_activity_stream():
self.instance = self.update_model(self.instance.pk, execution_environment=self.instance.resolve_execution_environment())
# self.instance because of the update_model pattern and when it's used in callback handlers
self.instance = self.update_model(pk, status='running', start_args='') # blank field to remove encrypted passwords
self.instance.websocket_emit_status("running")
status, rc = 'error', None
self.runner_callback.event_ct = 0
@@ -537,20 +485,12 @@ class BaseTask(object):
private_data_dir = None
try:
if self.instance.execution_environment_id is None:
from awx.main.signals import disable_activity_stream
with disable_activity_stream():
self.instance = self.update_model(self.instance.pk, execution_environment=self.instance.resolve_execution_environment())
self.instance.send_notification_templates("running")
private_data_dir = self.build_private_data_dir(self.instance)
self.pre_run_hook(self.instance, private_data_dir)
evaluate_policy(self.instance)
self.build_project_dir(self.instance, private_data_dir)
self.instance.log_lifecycle("preparing_playbook")
if self.instance.cancel_flag or signal_callback():
logger.debug(f'detected pre-run cancel flag for {self.instance.log_format}')
self.instance = self.update_model(self.instance.pk, status='canceled')
if self.instance.status != 'running':
@@ -673,11 +613,12 @@ class BaseTask(object):
elif status == 'canceled':
self.instance = self.update_model(pk)
cancel_flag_value = getattr(self.instance, 'cancel_flag', False)
if cancel_flag_value is False:
if (cancel_flag_value is False) and signal_callback():
self.runner_callback.delay_update(skip_if_already_set=True, job_explanation="Task was canceled due to receiving a shutdown signal.")
status = 'failed'
except PolicyEvaluationError as exc:
self.runner_callback.delay_update(job_explanation=str(exc), result_traceback=str(exc))
elif cancel_flag_value is False:
self.runner_callback.delay_update(skip_if_already_set=True, job_explanation="The running ansible process received a shutdown signal.")
status = 'failed'
except ReceptorNodeNotFound as exc:
self.runner_callback.delay_update(job_explanation=str(exc))
except Exception:
@@ -703,9 +644,6 @@ class BaseTask(object):
# Field host_status_counts is used as a metric to check if event processing is finished
# we send notifications if it is, if not, callback receiver will send them
if not self.instance:
logger.error(f'Unified job pk={pk} appears to be deleted while running')
return
if (self.instance.host_status_counts is not None) or (not self.runner_callback.wrapup_event_dispatched):
events_processed_hook(self.instance)
@@ -802,7 +740,6 @@ class SourceControlMixin(BaseTask):
try:
# the job private_data_dir is passed so sync can download roles and collections there
sync_task = RunProjectUpdate(job_private_data_dir=private_data_dir)
sync_task.instance = local_project_sync # avoids "waiting" status check, performance
sync_task.run(local_project_sync.id)
local_project_sync.refresh_from_db()
self.instance = self.update_model(self.instance.pk, scm_revision=local_project_sync.scm_revision)
@@ -866,7 +803,7 @@ class SourceControlMixin(BaseTask):
self.release_lock(project)
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
class RunJob(SourceControlMixin, BaseTask):
"""
Run a job using ansible-playbook.
@@ -1154,8 +1091,8 @@ class RunJob(SourceControlMixin, BaseTask):
# where ansible expects to find it
if self.should_use_fact_cache():
job.log_lifecycle("start_job_fact_cache")
self.hosts_with_facts_cached = start_fact_cache(
job.get_hosts_for_fact_cache(), artifacts_dir=os.path.join(private_data_dir, 'artifacts', str(job.id)), inventory_id=job.inventory_id
self.facts_write_time, self.hosts_with_facts_cached = start_fact_cache(
job.get_hosts_for_fact_cache(), os.path.join(private_data_dir, 'artifacts', str(job.id), 'fact_cache'), inventory_id=job.inventory_id
)
def build_project_dir(self, job, private_data_dir):
@@ -1165,7 +1102,7 @@ class RunJob(SourceControlMixin, BaseTask):
super(RunJob, self).post_run_hook(job, status)
job.refresh_from_db(fields=['job_env'])
private_data_dir = job.job_env.get('AWX_PRIVATE_DATA_DIR')
if not private_data_dir:
if (not private_data_dir) or (not hasattr(self, 'facts_write_time')):
# If there's no private data dir, that means we didn't get into the
# actual `run()` call; this _usually_ means something failed in
# the pre_run_hook method
@@ -1173,7 +1110,9 @@ class RunJob(SourceControlMixin, BaseTask):
if self.should_use_fact_cache() and self.runner_callback.artifacts_processed:
job.log_lifecycle("finish_job_fact_cache")
finish_fact_cache(
artifacts_dir=os.path.join(private_data_dir, 'artifacts', str(job.id)),
self.hosts_with_facts_cached,
os.path.join(private_data_dir, 'artifacts', str(job.id), 'fact_cache'),
facts_write_time=self.facts_write_time,
job_id=job.id,
inventory_id=job.inventory_id,
)
@@ -1189,7 +1128,7 @@ class RunJob(SourceControlMixin, BaseTask):
update_inventory_computed_fields.delay(inventory.id)
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
class RunProjectUpdate(BaseTask):
model = ProjectUpdate
event_model = ProjectUpdateEvent
@@ -1528,7 +1467,7 @@ class RunProjectUpdate(BaseTask):
return []
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
class RunInventoryUpdate(SourceControlMixin, BaseTask):
model = InventoryUpdate
event_model = InventoryUpdateEvent
@@ -1639,7 +1578,7 @@ class RunInventoryUpdate(SourceControlMixin, BaseTask):
# Include any facts from input inventories so they can be used in filters
start_fact_cache(
input_inventory.hosts.only(*HOST_FACTS_FIELDS),
artifacts_dir=os.path.join(private_data_dir, 'artifacts', str(inventory_update.id)),
os.path.join(private_data_dir, 'artifacts', str(inventory_update.id), 'fact_cache'),
inventory_id=input_inventory.id,
)
@@ -1791,7 +1730,7 @@ class RunInventoryUpdate(SourceControlMixin, BaseTask):
raise PostRunError('Error occured while saving inventory data, see traceback or server logs', status='error', tb=traceback.format_exc())
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
class RunAdHocCommand(BaseTask):
"""
Run an ad hoc command using ansible.
@@ -1944,7 +1883,7 @@ class RunAdHocCommand(BaseTask):
return d
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
class RunSystemJob(BaseTask):
model = SystemJob
event_model = SystemJobEvent

View File

@@ -1,458 +0,0 @@
import json
import tempfile
import contextlib
from pprint import pformat
from typing import Optional, Union
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from opa_client import OpaClient
from opa_client.base import BaseClient
from requests import HTTPError
from rest_framework import serializers
from rest_framework import fields
from awx.main import models
from awx.main.exceptions import PolicyEvaluationError
# opa_client.base.BaseClient.__init__ is replaced below so that the retries and
# timeout values supplied by the caller are stored on the client instance.
_original_opa_base_client_init = BaseClient.__init__


def _opa_base_client_init_fix(
    self,
    host: str = "localhost",
    port: int = 8181,
    version: str = "v1",
    ssl: bool = False,
    cert: Optional[Union[str, tuple]] = None,
    headers: Optional[dict] = None,
    retries: int = 2,
    timeout: float = 1.5,
):
    """Replacement initializer for ``BaseClient``.

    Calls the original ``__init__`` with the connection arguments only, then
    assigns ``retries`` and ``timeout`` on the instance.
    """
    _original_opa_base_client_init(self, host, port, version, ssl, cert, headers)
    self.retries, self.timeout = retries, timeout


BaseClient.__init__ = _opa_base_client_init_fix
class _TeamSerializer(serializers.ModelSerializer):
    """Serializes a Team as its ``id`` and ``name`` only."""

    class Meta:
        model = models.Team
        fields = (
            'id',
            'name',
        )
class _UserSerializer(serializers.ModelSerializer):
    """Serializes a User together with the teams computed by ``get_teams``."""

    teams = serializers.SerializerMethodField()

    class Meta:
        model = models.User
        fields = (
            'id',
            'username',
            'is_superuser',
            'teams',
        )

    def get_teams(self, user: models.User):
        """Return the serialized teams from ``Team.access_qs(user, 'member')``."""
        member_teams = models.Team.access_qs(user, 'member')
        team_serializer = _TeamSerializer(many=True)
        return team_serializer.to_representation(member_teams)
class _ExecutionEnvironmentSerializer(serializers.ModelSerializer):
    """Serializes an ExecutionEnvironment as id, name, image and pull."""

    class Meta:
        model = models.ExecutionEnvironment
        fields = ('id', 'name', 'image', 'pull')
class _InstanceGroupSerializer(serializers.ModelSerializer):
    """Serializes an InstanceGroup: identity plus its capacity and job-count fields."""

    class Meta:
        model = models.InstanceGroup
        fields = ('id', 'name', 'capacity', 'jobs_running', 'jobs_total', 'max_concurrent_jobs', 'max_forks')
class _InventorySourceSerializer(serializers.ModelSerializer):
    """Serializes an InventorySource as id, name, source and status."""

    class Meta:
        model = models.InventorySource
        fields = (
            'id',
            'name',
            'source',
            'status',
        )
class _InventorySerializer(serializers.ModelSerializer):
    """Serializes an Inventory with its summary counters and nested inventory sources."""

    # Each related inventory source is rendered with _InventorySourceSerializer.
    inventory_sources = _InventorySourceSerializer(many=True)

    class Meta:
        model = models.Inventory
        fields = (
            # identity
            'id',
            'name',
            'description',
            'kind',
            # summary counters and flags
            'total_hosts',
            'total_groups',
            'has_inventory_sources',
            'total_inventory_sources',
            'has_active_failures',
            'hosts_with_active_failures',
            # nested serializer declared on the class
            'inventory_sources',
        )
class _JobTemplateSerializer(serializers.ModelSerializer):
    """Serializes a JobTemplate as id, name and job_type."""

    class Meta:
        model = models.JobTemplate
        fields = ('id', 'name', 'job_type')
class _WorkflowJobTemplateSerializer(serializers.ModelSerializer):
    """Serializes a WorkflowJobTemplate as id, name and job_type."""

    class Meta:
        model = models.WorkflowJobTemplate
        # NOTE(review): confirm that WorkflowJobTemplate actually provides 'job_type'.
        fields = ('id', 'name', 'job_type')
class _WorkflowJobSerializer(serializers.ModelSerializer):
    """Serializes a WorkflowJob as its id and name."""

    class Meta:
        model = models.WorkflowJob
        fields = ('id', 'name')
class _OrganizationSerializer(serializers.ModelSerializer):
    """Serializes an Organization as its id and name."""

    class Meta:
        model = models.Organization
        fields = ('id', 'name')
class _ProjectSerializer(serializers.ModelSerializer):
    """Serializes a Project: identity, status and its scm_* fields."""

    class Meta:
        model = models.Project
        fields = (
            # identity and state
            'id',
            'name',
            'status',
            # source control fields
            'scm_type',
            'scm_url',
            'scm_branch',
            'scm_refspec',
            'scm_clean',
            'scm_track_submodules',
            'scm_delete_on_update',
        )
class _CredentialSerializer(serializers.ModelSerializer):
    """Serializes a Credential using only the fields listed in Meta, with a nested organization."""

    # The owning organization is rendered with _OrganizationSerializer.
    organization = _OrganizationSerializer()

    class Meta:
        model = models.Credential
        fields = (
            # identity
            'id',
            'name',
            'description',
            'organization',
            # type information
            'credential_type',
            'managed',
            'kind',
            'cloud',
            'kubernetes',
        )
class _LabelSerializer(serializers.ModelSerializer):
    """Serializes a Label as id and name, with a nested organization."""

    organization = _OrganizationSerializer()

    class Meta:
        model = models.Label
        fields = (
            'id',
            'name',
            'organization',
        )
class JobSerializer(serializers.ModelSerializer):
    """Serializes a Job, with nested related objects and four computed fields.

    Related objects use the private serializers of this module. The values of
    ``extra_vars``, ``hosts_count``, ``workflow_job`` and
    ``workflow_job_template`` come from the ``get_*`` methods below.
    """

    # Nested representations of related objects
    created_by = _UserSerializer()
    credentials = _CredentialSerializer(many=True)
    execution_environment = _ExecutionEnvironmentSerializer()
    instance_group = _InstanceGroupSerializer()
    inventory = _InventorySerializer()
    job_template = _JobTemplateSerializer()
    labels = _LabelSerializer(many=True)
    organization = _OrganizationSerializer()
    project = _ProjectSerializer()

    # Values computed by the get_<name> methods
    extra_vars = fields.SerializerMethodField()
    hosts_count = fields.SerializerMethodField()
    workflow_job = fields.SerializerMethodField()
    workflow_job_template = fields.SerializerMethodField()

    class Meta:
        model = models.Job
        fields = (
            'id',
            'name',
            'created',
            'created_by',
            'credentials',
            'execution_environment',
            'extra_vars',
            'forks',
            'hosts_count',
            'instance_group',
            'inventory',
            'job_template',
            'job_type',
            'job_type_name',
            'labels',
            'launch_type',
            'limit',
            'launched_by',
            'organization',
            'playbook',
            'project',
            'scm_branch',
            'scm_revision',
            'workflow_job',
            'workflow_job_template',
        )

    def get_extra_vars(self, obj: models.Job):
        """Return the job's ``display_extra_vars()`` string parsed from JSON."""
        displayed_vars = obj.display_extra_vars()
        return json.loads(displayed_vars)

    def get_hosts_count(self, obj: models.Job):
        """Return ``obj.hosts.count()``."""
        return obj.hosts.count()

    def get_workflow_job(self, obj: models.Job):
        """Return the serialized result of ``obj.get_workflow_job()``, or None when there is none."""
        parent_workflow: models.WorkflowJob = obj.get_workflow_job()
        if parent_workflow is not None:
            return _WorkflowJobSerializer().to_representation(parent_workflow)
        return None

    def get_workflow_job_template(self, obj: models.Job):
        """Return the serialized template of the job's workflow job, or None if either is missing."""
        parent_workflow: models.WorkflowJob = obj.get_workflow_job()
        template: models.WorkflowJobTemplate = None if parent_workflow is None else parent_workflow.workflow_job_template
        if template is not None:
            return _WorkflowJobTemplateSerializer().to_representation(template)
        return None
class OPAResultSerializer(serializers.Serializer):
    """Validates a policy result: a required boolean ``allowed`` and a list of string ``violations``."""

    allowed = fields.BooleanField(required=True)
    violations = fields.ListField(
        child=fields.CharField(),
    )
class OPA_AUTH_TYPES:
    """String constants that settings.OPA_AUTH_TYPE is compared against in this module."""

    NONE = 'None'
    TOKEN = 'Token'  # bearer token sent in the Authorization header
    CERTIFICATE = 'Certificate'  # client certificate and key written to a temporary PEM file
@contextlib.contextmanager
def opa_cert_file():
    """
    Context manager that writes the certificate material from settings to temporary PEM files.

    Three configurations are handled:
    - OPA_AUTH_TYPE is Certificate: the client certificate and key are written to one
      file; the CA certificate, when set, is written to a second file.
    - otherwise, OPA_SSL enabled: only the CA certificate file (when set) is written.
    - otherwise: no files are created.

    Yields:
        tuple: (client_cert_path, verify_path)
            - client_cert_path: path of the client cert/key file, or None
            - verify_path: path of the CA cert file, True to use the system CA store,
              or False when no TLS is configured

    The temporary files are closed when the context exits.
    """
    open_handles = []

    def _write_pem(*sections):
        # Register the handle before writing so that it is closed even if a write fails.
        handle = tempfile.NamedTemporaryFile(delete=True, mode='w', suffix=".pem")
        open_handles.append(handle)
        for section in sections:
            handle.write(section)
            handle.write("\n")
        handle.flush()
        return handle.name

    def _server_verification():
        # A configured CA certificate is used for server verification,
        # otherwise the system CA store (True) is used.
        if settings.OPA_AUTH_CA_CERT:
            return _write_pem(settings.OPA_AUTH_CA_CERT)
        return True

    try:
        if settings.OPA_AUTH_TYPE == OPA_AUTH_TYPES.CERTIFICATE:
            # Full mTLS: client certificate and key share a single PEM file
            client_cert_path = _write_pem(settings.OPA_AUTH_CLIENT_CERT, settings.OPA_AUTH_CLIENT_KEY)
            yield (client_cert_path, _server_verification())
        elif settings.OPA_SSL:
            # TLS with server verification only (no client cert)
            yield (None, _server_verification())
        else:
            # No TLS
            yield (None, False)
    finally:
        # Clean up temporary files in creation order
        for handle in open_handles:
            handle.close()
@contextlib.contextmanager
def opa_client(headers=None):
    """Yield an ``OpaClient`` built from the OPA_* settings.

    The client is created inside ``opa_cert_file()``, so the certificate files
    it refers to exist for as long as the client is in use.

    :param headers: passed through unchanged as the client's ``headers`` argument.
    """
    with opa_cert_file() as (client_cert, verify_setting):
        connection_options = dict(
            host=settings.OPA_HOST,
            port=settings.OPA_PORT,
            headers=headers,
            ssl=settings.OPA_SSL,
            cert=client_cert,
            timeout=settings.OPA_REQUEST_TIMEOUT,
            retries=settings.OPA_REQUEST_RETRIES,
        )
        with OpaClient(**connection_options) as client:
            # Workaround for https://github.com/Turall/OPA-python-client/issues/32
            # by directly setting cert and verify on requests.session
            session = client._session
            session.cert = client_cert
            session.verify = verify_setting
            yield client
def evaluate_policy(instance):
    """Query OPA for the policies attached to a job and raise if the job may not run.

    Nothing is evaluated when settings.OPA_HOST is unset or when ``instance``
    is not a ``models.Job``. Otherwise the job is serialized with
    ``JobSerializer`` and sent as input to the ``opa_query_path`` of its
    organization, inventory and job template; empty paths are skipped.

    :param instance: the unified job about to be started.
    :raises PolicyEvaluationError: when the certificate settings are
        inconsistent, when any OPA call fails or returns an unusable result,
        or when a policy reports violations.
    """
    # Policy evaluation for Policy as Code feature
    if not settings.OPA_HOST:
        return

    if not isinstance(instance, models.Job):
        return

    instance.log_lifecycle("evaluate_policy")

    input_data = JobSerializer(instance=instance).data

    headers = settings.OPA_AUTH_CUSTOM_HEADERS
    if settings.OPA_AUTH_TYPE == OPA_AUTH_TYPES.TOKEN:
        # Build a new dict instead of calling .update() on the settings value:
        # updating in place would leave the bearer token stored on the settings
        # object, and would fail outright when the setting is None.
        headers = {**(headers or {}), 'Authorization': 'Bearer {}'.format(settings.OPA_AUTH_TOKEN)}

    if settings.OPA_AUTH_TYPE == OPA_AUTH_TYPES.CERTIFICATE and not settings.OPA_SSL:
        raise PolicyEvaluationError(_('OPA_AUTH_TYPE=Certificate requires OPA_SSL to be enabled.'))

    cert_settings_missing = []

    if settings.OPA_AUTH_TYPE == OPA_AUTH_TYPES.CERTIFICATE:
        if not settings.OPA_AUTH_CLIENT_CERT:
            cert_settings_missing += ['OPA_AUTH_CLIENT_CERT']
        if not settings.OPA_AUTH_CLIENT_KEY:
            cert_settings_missing += ['OPA_AUTH_CLIENT_KEY']
        if not settings.OPA_AUTH_CA_CERT:
            cert_settings_missing += ['OPA_AUTH_CA_CERT']

    if cert_settings_missing:
        raise PolicyEvaluationError(_('Following certificate settings are missing for OPA_AUTH_TYPE=Certificate: {}').format(cert_settings_missing))

    # NOTE(review): each attribute access below raises AttributeError if the
    # related object is None - confirm callers guarantee all three are set.
    query_paths = [
        ('Organization', instance.organization.opa_query_path),
        ('Inventory', instance.inventory.opa_query_path),
        ('Job template', instance.job_template.opa_query_path),
    ]
    violations = dict()
    errors = dict()

    try:
        with opa_client(headers=headers) as client:
            for path_type, query_path in query_paths:
                response = dict()
                try:
                    if not query_path:
                        continue

                    response = client.query_rule(input_data=input_data, package_path=query_path)

                except HTTPError as e:
                    message = _('Call to OPA failed. Exception: {}').format(e)
                    try:
                        error_data = e.response.json()
                    except ValueError:
                        # Response body is not JSON; report the exception text only
                        errors[path_type] = message
                        continue

                    error_code = error_data.get("code")
                    error_message = error_data.get("message")
                    if error_code or error_message:
                        message = _('Call to OPA failed. Code: {}, Message: {}').format(error_code, error_message)
                    errors[path_type] = message
                    continue

                except Exception as e:
                    errors[path_type] = _('Call to OPA failed. Exception: {}').format(e)
                    continue

                result = response.get('result')
                if result is None:
                    errors[path_type] = _('Call to OPA did not return a "result" property. The path refers to an undefined document.')
                    continue

                result_serializer = OPAResultSerializer(data=result)
                if not result_serializer.is_valid():
                    errors[path_type] = _('OPA policy returned invalid result.')
                    continue

                result_data = result_serializer.validated_data
                # NOTE(review): a result with allowed=False and an empty
                # violations list records nothing and so does not block the
                # job - confirm this is intended.
                if not result_data.get("allowed") and (result_violations := result_data.get("violations")):
                    violations[path_type] = result_violations

            format_results = dict()
            if any(errors[e] for e in errors):
                format_results["Errors"] = errors

            if any(violations[v] for v in violations):
                format_results["Violations"] = violations

            if violations or errors:
                raise PolicyEvaluationError(pformat(format_results, width=80))
    except Exception as e:
        # This also wraps the PolicyEvaluationError raised just above, adding the
        # user-facing prefix; the original exception is kept as the cause.
        raise PolicyEvaluationError(_('This job cannot be executed due to a policy violation or error. See the following details:\n{}').format(e)) from e

View File

@@ -32,7 +32,7 @@ from awx.main.constants import MAX_ISOLATED_PATH_COLON_DELIMITER
from awx.main.tasks.signals import signal_state, signal_callback, SignalExit
from awx.main.models import Instance, InstanceLink, UnifiedJob, ReceptorAddress
from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.publish import task as task_awx
from awx.main.dispatch.publish import task
# Receptorctl
from receptorctl.socket_interface import ReceptorControl
@@ -852,7 +852,7 @@ def reload_receptor():
raise RuntimeError("Receptor reload failed")
@task_awx()
@task()
def write_receptor_config():
"""
This task runs async on each control node, K8S only.
@@ -875,7 +875,7 @@ def write_receptor_config():
reload_receptor()
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def remove_deprovisioned_node(hostname):
InstanceLink.objects.filter(source__hostname=hostname).update(link_state=InstanceLink.States.REMOVING)
InstanceLink.objects.filter(target__instance__hostname=hostname).update(link_state=InstanceLink.States.REMOVING)

View File

@@ -14,21 +14,16 @@ class SignalExit(Exception):
class SignalState:
# SIGTERM: Sent by supervisord to process group on shutdown
# SIGUSR1: The dispatcherd cancel signal
signals = (signal.SIGTERM, signal.SIGINT, signal.SIGUSR1)
def reset(self):
for for_signal in self.signals:
self.signal_flags[for_signal] = False
self.original_methods[for_signal] = None
self.sigterm_flag = False
self.sigint_flag = False
self.is_active = False # for nested context managers
self.original_sigterm = None
self.original_sigint = None
self.raise_exception = False
def __init__(self):
self.signal_flags = {}
self.original_methods = {}
self.reset()
def raise_if_needed(self):
@@ -36,28 +31,31 @@ class SignalState:
self.raise_exception = False # so it is not raised a second time in error handling
raise SignalExit()
def set_signal_flag(self, *args, for_signal=None):
self.signal_flags[for_signal] = True
logger.info(f'Processed signal {for_signal}, set exit flag')
def set_sigterm_flag(self, *args):
self.sigterm_flag = True
self.raise_if_needed()
def set_sigint_flag(self, *args):
self.sigint_flag = True
self.raise_if_needed()
def connect_signals(self):
for for_signal in self.signals:
self.original_methods[for_signal] = signal.getsignal(for_signal)
signal.signal(for_signal, lambda *args, for_signal=for_signal: self.set_signal_flag(*args, for_signal=for_signal))
self.original_sigterm = signal.getsignal(signal.SIGTERM)
self.original_sigint = signal.getsignal(signal.SIGINT)
signal.signal(signal.SIGTERM, self.set_sigterm_flag)
signal.signal(signal.SIGINT, self.set_sigint_flag)
self.is_active = True
def restore_signals(self):
for for_signal in self.signals:
original_method = self.original_methods[for_signal]
signal.signal(for_signal, original_method)
# if we got a signal while context manager was active, call parent methods.
if self.signal_flags[for_signal]:
if callable(original_method):
try:
original_method()
except Exception as exc:
logger.info(f'Error processing original {for_signal} signal, error: {str(exc)}')
signal.signal(signal.SIGTERM, self.original_sigterm)
signal.signal(signal.SIGINT, self.original_sigint)
# if we got a signal while context manager was active, call parent methods.
if self.sigterm_flag:
if callable(self.original_sigterm):
self.original_sigterm()
if self.sigint_flag:
if callable(self.original_sigint):
self.original_sigint()
self.reset()
@@ -65,7 +63,7 @@ signal_state = SignalState()
def signal_callback():
return any(signal_state.signal_flags[for_signal] for for_signal in signal_state.signals)
return bool(signal_state.sigterm_flag or signal_state.sigint_flag)
def with_signal_handling(f):

View File

@@ -1,77 +1,78 @@
# Python
from collections import namedtuple
import functools
import importlib
import itertools
import json
import logging
import os
import psycopg
from io import StringIO
from contextlib import redirect_stdout
import shutil
import time
from collections import namedtuple
from contextlib import redirect_stdout
from datetime import datetime
from distutils.version import LooseVersion as Version
from io import StringIO
from datetime import datetime
# Runner
import ansible_runner.cleanup
import psycopg
from ansible_base.lib.utils.db import advisory_lock
# django-ansible-base
from ansible_base.resource_registry.tasks.sync import SyncExecutor
# Django
from django.conf import settings
from django.db import connection, transaction, DatabaseError, IntegrityError
from django.db.models.fields.related import ForeignKey
from django.utils.timezone import now, timedelta
from django.utils.encoding import smart_str
from django.contrib.auth.models import User
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext_noop
from django.core.cache import cache
from django.core.exceptions import ObjectDoesNotExist
from django.db.models.query import QuerySet
# Django-CRUM
from crum import impersonate
# Django flags
from flags.state import flag_enabled
# Runner
import ansible_runner.cleanup
# dateutil
from dateutil.parser import parse as parse_date
# Django
from django.conf import settings
from django.contrib.auth.models import User
from django.core.cache import cache
from django.core.exceptions import ObjectDoesNotExist
from django.db import DatabaseError, IntegrityError, connection, transaction
from django.db.models.fields.related import ForeignKey
from django.db.models.query import QuerySet
from django.utils.encoding import smart_str
from django.utils.timezone import now, timedelta
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext_noop
# Django flags
from flags.state import flag_enabled
from rest_framework.exceptions import PermissionDenied
# django-ansible-base
from ansible_base.resource_registry.tasks.sync import SyncExecutor
from ansible_base.lib.utils.db import advisory_lock
# AWX
from awx import __version__ as awx_application_version
from awx.conf import settings_registry
from awx.main import analytics
from awx.main.access import access_registry
from awx.main.analytics.subsystem_metrics import DispatcherMetrics
from awx.main.constants import ACTIVE_STATES, ERROR_STATES
from awx.main.consumers import emit_channel_notification
from awx.main.dispatch import get_task_queuename, reaper
from awx.main.dispatch.publish import task as task_awx
from awx.main.models import (
Schedule,
TowerScheduleState,
Instance,
InstanceGroup,
Inventory,
Job,
Notification,
Schedule,
SmartInventoryMembership,
TowerScheduleState,
UnifiedJob,
Notification,
Inventory,
SmartInventoryMembership,
Job,
convert_jsonfields,
)
from awx.main.constants import ACTIVE_STATES, ERROR_STATES
from awx.main.dispatch.publish import task
from awx.main.dispatch import get_task_queuename, reaper
from awx.main.utils.common import ignore_inventory_computed_fields, ignore_inventory_group_removal
from awx.main.utils.reload import stop_local_services
from awx.main.tasks.helpers import is_run_threshold_reached
from awx.main.tasks.host_indirect import save_indirect_host_entries
from awx.main.tasks.receptor import administrative_workunit_reaper, get_receptor_ctl, worker_cleanup, worker_info, write_receptor_config
from awx.main.utils.common import ignore_inventory_computed_fields, ignore_inventory_group_removal
from awx.main.utils.reload import stop_local_services
from dispatcherd.publish import task
from awx.main.tasks.receptor import get_receptor_ctl, worker_info, worker_cleanup, administrative_workunit_reaper, write_receptor_config
from awx.main.consumers import emit_channel_notification
from awx.main import analytics
from awx.conf import settings_registry
from awx.main.analytics.subsystem_metrics import DispatcherMetrics
from rest_framework.exceptions import PermissionDenied
logger = logging.getLogger('awx.main.tasks.system')
@@ -82,12 +83,7 @@ Try upgrading OpenSSH or providing your private key in an different format. \
'''
def _run_dispatch_startup_common():
"""
Execute the common startup initialization steps.
This includes updating schedules, syncing instance membership, and starting
local reaping and resetting metrics.
"""
def dispatch_startup():
startup_logger = logging.getLogger('awx.main.tasks')
# TODO: Enable this on VM installs
@@ -97,14 +93,14 @@ def _run_dispatch_startup_common():
try:
convert_jsonfields()
except Exception:
logger.exception("Failed JSON field conversion, skipping.")
logger.exception("Failed json field conversion, skipping.")
startup_logger.debug("Syncing schedules")
startup_logger.debug("Syncing Schedules")
for sch in Schedule.objects.all():
try:
sch.update_computed_fields()
except Exception:
logger.exception("Failed to rebuild schedule %s.", sch)
logger.exception("Failed to rebuild schedule {}.".format(sch))
#
# When the dispatcher starts, if the instance cannot be found in the database,
@@ -124,67 +120,25 @@ def _run_dispatch_startup_common():
apply_cluster_membership_policies()
cluster_node_heartbeat()
reaper.startup_reaping()
reaper.reap_waiting(grace_period=0)
m = DispatcherMetrics()
m.reset_values()
def _legacy_dispatch_startup():
"""
Legacy branch for startup: simply performs reaping of waiting jobs with a zero grace period.
"""
logger.debug("Legacy dispatcher: calling reaper.reap_waiting with grace_period=0")
reaper.reap_waiting(grace_period=0)
def _dispatcherd_dispatch_startup():
"""
New dispatcherd branch for startup: uses the control API to re-submit waiting jobs.
"""
logger.debug("Dispatcherd enabled: dispatching waiting jobs via control channel")
from awx.main.tasks.jobs import dispatch_waiting_jobs
dispatch_waiting_jobs.apply_async(queue=get_task_queuename())
def dispatch_startup():
"""
System initialization at startup.

First, execute the common logic (_run_dispatch_startup_common).
Then, if the FEATURE_DISPATCHERD_ENABLED flag is on, hand the waiting jobs
to the dispatcherd branch (_dispatcherd_dispatch_startup);
otherwise, fall back to legacy reaping of waiting jobs
(_legacy_dispatch_startup).
"""
_run_dispatch_startup_common()
# Exactly one of the two branches runs, chosen by the feature flag.
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
_dispatcherd_dispatch_startup()
else:
_legacy_dispatch_startup()
def inform_cluster_of_shutdown():
"""
Clean system shutdown that marks the current instance offline.
In legacy mode, it also reaps waiting jobs.
In dispatcherd mode, it relies on dispatcherd's built-in cleanup.
"""
try:
inst = Instance.objects.get(hostname=settings.CLUSTER_HOST_ID)
inst.mark_offline(update_last_seen=True, errors=_('Instance received normal shutdown signal'))
except Instance.DoesNotExist:
logger.exception("Cluster host not found: %s", settings.CLUSTER_HOST_ID)
return
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
logger.debug("Dispatcherd mode: no extra reaping required for instance %s", inst.hostname)
else:
this_inst = Instance.objects.get(hostname=settings.CLUSTER_HOST_ID)
this_inst.mark_offline(update_last_seen=True, errors=_('Instance received normal shutdown signal'))
try:
logger.debug("Legacy mode: reaping waiting jobs for instance %s", inst.hostname)
reaper.reap_waiting(inst, grace_period=0)
reaper.reap_waiting(this_inst, grace_period=0)
except Exception:
logger.exception("Failed to reap waiting jobs for %s", inst.hostname)
logger.warning("Normal shutdown processed for instance %s; instance removed from capacity pool.", inst.hostname)
logger.exception('failed to reap waiting jobs for {}'.format(this_inst.hostname))
logger.warning('Normal shutdown signal for instance {}, removed self from capacity pool.'.format(this_inst.hostname))
except Exception:
logger.exception('Encountered problem with normal shutdown signal.')
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def migrate_jsonfield(table, pkfield, columns):
batchsize = 10000
with advisory_lock(f'json_migration_{table}', wait=False) as acquired:
@@ -230,7 +184,7 @@ def migrate_jsonfield(table, pkfield, columns):
logger.warning(f"Migration of {table} to jsonb is finished.")
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def apply_cluster_membership_policies():
from awx.main.signals import disable_activity_stream
@@ -342,7 +296,7 @@ def apply_cluster_membership_policies():
logger.debug('Cluster policy computation finished in {} seconds'.format(time.time() - started_compute))
@task_awx(queue='tower_settings_change')
@task(queue='tower_settings_change')
def clear_setting_cache(setting_keys):
# log that cache is being cleared
logger.info(f"clear_setting_cache of keys {setting_keys}")
@@ -355,7 +309,7 @@ def clear_setting_cache(setting_keys):
cache.delete_many(cache_keys)
@task_awx(queue='tower_broadcast_all')
@task(queue='tower_broadcast_all')
def delete_project_files(project_path):
# TODO: possibly implement some retry logic
lock_file = project_path + '.lock'
@@ -373,7 +327,7 @@ def delete_project_files(project_path):
logger.exception('Could not remove lock file {}'.format(lock_file))
@task_awx(queue='tower_broadcast_all')
@task(queue='tower_broadcast_all')
def profile_sql(threshold=1, minutes=1):
if threshold <= 0:
cache.delete('awx-profile-sql-threshold')
@@ -383,7 +337,7 @@ def profile_sql(threshold=1, minutes=1):
logger.error('SQL QUERIES >={}s ENABLED FOR {} MINUTE(S)'.format(threshold, minutes))
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def send_notifications(notification_list, job_id=None):
if not isinstance(notification_list, list):
raise TypeError("notification_list should be of type list")
@@ -428,13 +382,13 @@ def events_processed_hook(unified_job):
save_indirect_host_entries.delay(unified_job.id)
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def gather_analytics():
if is_run_threshold_reached(getattr(settings, 'AUTOMATION_ANALYTICS_LAST_GATHER', None), settings.AUTOMATION_ANALYTICS_GATHER_INTERVAL):
analytics.gather()
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def purge_old_stdout_files():
nowtime = time.time()
for f in os.listdir(settings.JOBOUTPUT_ROOT):
@@ -496,18 +450,18 @@ class CleanupImagesAndFiles:
cls.run_remote(this_inst, **kwargs)
@task_awx(queue='tower_broadcast_all')
@task(queue='tower_broadcast_all')
def handle_removed_image(remove_images=None):
"""Special broadcast invocation of this method to handle case of deleted EE"""
CleanupImagesAndFiles.run(remove_images=remove_images, file_pattern='')
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def cleanup_images_and_files():
CleanupImagesAndFiles.run(image_prune=True)
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def cluster_node_health_check(node):
"""
Used for the health check endpoint, refreshes the status of the instance, but must be ran on target node
@@ -526,7 +480,7 @@ def cluster_node_health_check(node):
this_inst.local_health_check()
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def execution_node_health_check(node):
if node == '':
logger.warning('Remote health check incorrectly called with blank string')
@@ -594,16 +548,8 @@ def inspect_established_receptor_connections(mesh_status):
def inspect_execution_and_hop_nodes(instance_list):
with advisory_lock('inspect_execution_and_hop_nodes_lock', wait=False):
node_lookup = {inst.hostname: inst for inst in instance_list}
try:
ctl = get_receptor_ctl()
except FileNotFoundError:
logger.error('Receptor daemon not running, skipping execution node check')
return
try:
mesh_status = ctl.simple_command('status')
except ValueError as exc:
logger.error(f'Error running receptorctl status command, error: {str(exc)}')
return
ctl = get_receptor_ctl()
mesh_status = ctl.simple_command('status')
inspect_established_receptor_connections(mesh_status)
@@ -651,109 +597,8 @@ def inspect_execution_and_hop_nodes(instance_list):
execution_node_health_check.apply_async([hostname])
@task_awx(queue=get_task_queuename, bind_kwargs=['dispatch_time', 'worker_tasks'])
@task(queue=get_task_queuename, bind_kwargs=['dispatch_time', 'worker_tasks'])
def cluster_node_heartbeat(dispatch_time=None, worker_tasks=None):
"""
Original implementation for AWX dispatcher.
Uses worker_tasks from bind_kwargs to track running tasks.
"""
# Run common instance management logic
this_inst, instance_list, lost_instances = _heartbeat_instance_management()
if this_inst is None:
return # Early return case from instance management
# Check versions
_heartbeat_check_versions(this_inst, instance_list)
# Handle lost instances
_heartbeat_handle_lost_instances(lost_instances, this_inst)
# Run local reaper - original implementation using worker_tasks
if worker_tasks is not None:
active_task_ids = []
for task_list in worker_tasks.values():
active_task_ids.extend(task_list)
# Convert dispatch_time to datetime
ref_time = datetime.fromisoformat(dispatch_time) if dispatch_time else now()
reaper.reap(instance=this_inst, excluded_uuids=active_task_ids, ref_time=ref_time)
if max(len(task_list) for task_list in worker_tasks.values()) <= 1:
reaper.reap_waiting(instance=this_inst, excluded_uuids=active_task_ids, ref_time=ref_time)
@task(queue=get_task_queuename, bind=True)
def adispatch_cluster_node_heartbeat(binder):
"""
Dispatcherd implementation of the cluster node heartbeat.

Runs the shared heartbeat helpers (instance management, version check,
lost-instance handling), then asks dispatcherd which tasks are running and
reaps everything else on this instance.

Args:
binder: bound task object; only passed on to
_get_active_task_ids_from_dispatcherd(), which calls binder.control('running').

Returns:
None. Returns early, without reaping, when instance management yields no
instance or when the running-task list could not be retrieved.
"""
# Run common instance management logic
this_inst, instance_list, lost_instances = _heartbeat_instance_management()
if this_inst is None:
return # Early return case from instance management
# Check versions
_heartbeat_check_versions(this_inst, instance_list)
# Handle lost instances
_heartbeat_handle_lost_instances(lost_instances, this_inst)
# Get running tasks using dispatcherd API
active_task_ids = _get_active_task_ids_from_dispatcherd(binder)
# Only None means failure; an empty list is a valid answer and still
# proceeds to the reaper below.
if active_task_ids is None:
logger.warning("No active task IDs retrieved from dispatcherd, skipping reaper")
return # Failed to get task IDs, don't attempt reaping
# Run local reaper using tasks from dispatcherd
ref_time = now() # No dispatch_time in dispatcherd version
logger.debug(f"Running reaper with {len(active_task_ids)} excluded UUIDs")
reaper.reap(instance=this_inst, excluded_uuids=active_task_ids, ref_time=ref_time)
# If waiting jobs are hanging out, resubmit them
# ("waiting" here means status='waiting' with this host as controller_node)
if UnifiedJob.objects.filter(controller_node=settings.CLUSTER_HOST_ID, status='waiting').exists():
from awx.main.tasks.jobs import dispatch_waiting_jobs
dispatch_waiting_jobs.apply_async(queue=get_task_queuename())
def _get_active_task_ids_from_dispatcherd(binder):
"""
Retrieve active task IDs from the dispatcherd control API.

Args:
binder: object whose control('running') call returns the running-task
data. The code below handles that return value as a dict.

Returns:
list: List of active task UUIDs (may be empty)
None: If any exception was raised while querying or parsing the data
"""
active_task_ids = []
try:
logger.debug("Querying dispatcherd API for running tasks")
data = binder.control('running')
# Extract UUIDs from the running data
# The reply is treated as one dict: drop its 'node_id' entry (in place)
# so that only task entries remain.
data.pop('node_id', None)
# Extract task UUIDs from data structure
for task_key, task_value in data.items():
if isinstance(task_value, dict) and 'uuid' in task_value:
active_task_ids.append(task_value['uuid'])
logger.debug(f"Found active task with UUID: {task_value['uuid']}")
elif isinstance(task_key, str):
# Handle case where UUID might be the key
# NOTE(review): every remaining string key whose value has no 'uuid'
# is counted as a task ID -- confirm no other metadata keys can appear.
active_task_ids.append(task_key)
logger.debug(f"Found active task with key: {task_key}")
logger.debug(f"Retrieved {len(active_task_ids)} active task IDs from dispatcherd")
return active_task_ids
except Exception:
# Deliberately broad: any failure (including a reply that is not a dict)
# is logged and reported as None, which the heartbeat caller treats as
# "skip reaping".
logger.exception("Failed to get running tasks from dispatcherd")
return None
def _heartbeat_instance_management():
"""Common logic for heartbeat instance management."""
logger.debug("Cluster node heartbeat task.")
nowtime = now()
instance_list = list(Instance.objects.filter(node_state__in=(Instance.States.READY, Instance.States.UNAVAILABLE, Instance.States.INSTALLED)))
@@ -780,7 +625,7 @@ def _heartbeat_instance_management():
this_inst.local_health_check()
if startup_event and this_inst.capacity != 0:
logger.warning(f'Rejoining the cluster as instance {this_inst.hostname}. Prior last_seen {last_last_seen}')
return None, None, None # Early return case
return
elif not last_last_seen:
logger.warning(f'Instance does not have recorded last_seen, updating to {nowtime}')
elif (nowtime - last_last_seen) > timedelta(seconds=settings.CLUSTER_NODE_HEARTBEAT_PERIOD + 2):
@@ -792,14 +637,8 @@ def _heartbeat_instance_management():
logger.warning(f'Recreated instance record {this_inst.hostname} after unexpected removal')
this_inst.local_health_check()
else:
logger.error("Cluster Host Not Found: {}".format(settings.CLUSTER_HOST_ID))
return None, None, None
return this_inst, instance_list, lost_instances
def _heartbeat_check_versions(this_inst, instance_list):
"""Check versions across instances and determine if shutdown is needed."""
raise RuntimeError("Cluster Host Not Found: {}".format(settings.CLUSTER_HOST_ID))
# IFF any node has a greater version than we do, then we'll shutdown services
for other_inst in instance_list:
if other_inst.node_type in ('execution', 'hop'):
continue
@@ -816,9 +655,6 @@ def _heartbeat_check_versions(this_inst, instance_list):
stop_local_services(communicate=False)
raise RuntimeError("Shutting down.")
def _heartbeat_handle_lost_instances(lost_instances, this_inst):
"""Handle lost instances by reaping their jobs and marking them offline."""
for other_inst in lost_instances:
try:
explanation = "Job reaped due to instance shutdown"
@@ -849,8 +685,17 @@ def _heartbeat_handle_lost_instances(lost_instances, this_inst):
else:
logger.exception('No SQL state available. Error marking {} as lost'.format(other_inst.hostname))
# Run local reaper
if worker_tasks is not None:
active_task_ids = []
for task_list in worker_tasks.values():
active_task_ids.extend(task_list)
reaper.reap(instance=this_inst, excluded_uuids=active_task_ids, ref_time=datetime.fromisoformat(dispatch_time))
if max(len(task_list) for task_list in worker_tasks.values()) <= 1:
reaper.reap_waiting(instance=this_inst, excluded_uuids=active_task_ids, ref_time=datetime.fromisoformat(dispatch_time))
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def awx_receptor_workunit_reaper():
"""
When an AWX job is launched via receptor, files such as status, stdin, and stdout are created
@@ -873,16 +718,8 @@ def awx_receptor_workunit_reaper():
if not settings.RECEPTOR_RELEASE_WORK:
return
logger.debug("Checking for unreleased receptor work units")
try:
receptor_ctl = get_receptor_ctl()
except FileNotFoundError:
logger.info('Receptorctl sockfile not found for workunit reaper, doing nothing')
return
try:
receptor_work_list = receptor_ctl.simple_command("work list")
except ValueError as exc:
logger.info(f'Error getting work list for workunit reaper, error: {str(exc)}')
return
receptor_ctl = get_receptor_ctl()
receptor_work_list = receptor_ctl.simple_command("work list")
unit_ids = [id for id in receptor_work_list]
jobs_with_unreleased_receptor_units = UnifiedJob.objects.filter(work_unit_id__in=unit_ids).exclude(status__in=ACTIVE_STATES)
@@ -896,7 +733,7 @@ def awx_receptor_workunit_reaper():
administrative_workunit_reaper(receptor_work_list)
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def awx_k8s_reaper():
if not settings.RECEPTOR_RELEASE_WORK:
return
@@ -919,7 +756,7 @@ def awx_k8s_reaper():
logger.exception("Failed to delete orphaned pod {} from {}".format(job.log_format, group))
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def awx_periodic_scheduler():
lock_session_timeout_milliseconds = settings.TASK_MANAGER_LOCK_TIMEOUT * 1000
with advisory_lock('awx_periodic_scheduler_lock', lock_session_timeout_milliseconds=lock_session_timeout_milliseconds, wait=False) as acquired:
@@ -978,7 +815,7 @@ def awx_periodic_scheduler():
emit_channel_notification('schedules-changed', dict(id=schedule.id, group_name="schedules"))
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def handle_failure_notifications(task_ids):
"""A task-ified version of the method that sends notifications."""
found_task_ids = set()
@@ -993,7 +830,7 @@ def handle_failure_notifications(task_ids):
logger.warning(f'Could not send notifications for {deleted_tasks} because they were not found in the database')
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def update_inventory_computed_fields(inventory_id):
"""
Signal handler and wrapper around inventory.update_computed_fields to
@@ -1043,7 +880,7 @@ def update_smart_memberships_for_inventory(smart_inventory):
return False
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def update_host_smart_inventory_memberships():
smart_inventories = Inventory.objects.filter(kind='smart', host_filter__isnull=False, pending_deletion=False)
changed_inventories = set([])
@@ -1059,7 +896,7 @@ def update_host_smart_inventory_memberships():
smart_inventory.update_computed_fields()
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def delete_inventory(inventory_id, user_id, retries=5):
# Delete inventory as user
if user_id is None:
@@ -1121,7 +958,7 @@ def _reconstruct_relationships(copy_mapping):
new_obj.save()
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def deep_copy_model_obj(model_module, model_name, obj_pk, new_obj_pk, user_pk, permission_check_func=None):
logger.debug('Deep copy {} from {} to {}.'.format(model_name, obj_pk, new_obj_pk))
@@ -1176,7 +1013,7 @@ def deep_copy_model_obj(model_module, model_name, obj_pk, new_obj_pk, user_pk, p
update_inventory_computed_fields.delay(new_obj.id)
@task_awx(queue=get_task_queuename)
@task(queue=get_task_queuename)
def periodic_resource_sync():
if not getattr(settings, 'RESOURCE_SERVER', None):
logger.debug("Skipping periodic resource_sync, RESOURCE_SERVER not configured")

View File

@@ -8,12 +8,5 @@
"CONTROLLER_PASSWORD": "fooo",
"CONTROLLER_USERNAME": "fooo",
"CONTROLLER_OAUTH_TOKEN": "",
"CONTROLLER_VERIFY_SSL": "False",
"AAP_HOSTNAME": "https://foo.invalid",
"AAP_PASSWORD": "fooo",
"AAP_USERNAME": "fooo",
"AAP_VALIDATE_CERTS": "False",
"CONTROLLER_REQUEST_TIMEOUT": "fooo",
"AAP_REQUEST_TIMEOUT": "fooo",
"AAP_TOKEN": ""
"CONTROLLER_VERIFY_SSL": "False"
}

View File

@@ -1,9 +0,0 @@
---
- hosts: all
gather_facts: false
connection: local
vars:
sleep_interval: 5
tasks:
- name: sleep for a specified interval
command: sleep '{{ sleep_interval }}'

View File

@@ -1,3 +0,0 @@
[all:vars]
a=value_a
b=value_b

View File

@@ -1,57 +0,0 @@
import time
import logging
from dispatcherd.publish import task
from django.db import connection
from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.publish import task as old_task
from ansible_base.lib.utils.db import advisory_lock
logger = logging.getLogger(__name__)
@old_task(queue=get_task_queuename)
def sleep_task(seconds=10, log=False):
"""
Sleep for `seconds` seconds.

When `log` is truthy, an info-level line is logged before and after the sleep.
Registered with the old (awx.main.dispatch.publish) task decorator.
"""
if log:
logger.info('starting sleep_task')
time.sleep(seconds)
if log:
logger.info('finished sleep_task')
@task()
def sleep_break_connection(seconds=0.2):
"""
Interact with the database in an intentionally breaking way.
After this finishes, queries made by this connection are expected to error
with "the connection is closed"
This is obviously a problem for any task that comes afterwards.
So this is used to break things so that the fixes may be demonstrated.

Args:
seconds: how long to sleep; the session's idle_session_timeout is set to
half of this value, so the sleep outlasts the timeout.
"""
with connection.cursor() as cursor:
# Timeout is half the sleep duration, so the session is idle for longer than allowed.
cursor.execute(f"SET idle_session_timeout = '{seconds / 2}s';")
logger.info(f'sleeping for {seconds}s > {seconds / 2}s session timeout')
time.sleep(seconds)
# range(1, 3) -> two query attempts, numbered 1 and 2.
for i in range(1, 3):
logger.info(f'\nRunning query number {i}')
try:
with connection.cursor() as cursor:
cursor.execute("SELECT 1;")
logger.info(' query worked, not expected')
except Exception as exc:
logger.info(f' query errored as expected\ntype: {type(exc)}\nstr: {str(exc)}')
# Report what the Django connection wrapper believes about the underlying connection.
logger.info(f'Connection present: {bool(connection.connection)}, reports closed: {getattr(connection.connection, "closed", "not_found")}')
@task()
def advisory_lock_exception():
"""
Raise a RuntimeError while holding the 'advisory_lock_exception' advisory lock.

The lock is requested with lock_session_timeout_milliseconds=20. The error is
intentional; this task never returns normally.
"""
time.sleep(0.2) # so it can fill up all the workers... hacky for now
with advisory_lock('advisory_lock_exception', lock_session_timeout_milliseconds=20):
raise RuntimeError('this is an intentional error')

View File

@@ -87,8 +87,8 @@ def mock_analytic_post():
{
'REDHAT_USERNAME': 'redhat_user',
'REDHAT_PASSWORD': 'redhat_pass', # NOSONAR
'SUBSCRIPTIONS_CLIENT_ID': '',
'SUBSCRIPTIONS_CLIENT_SECRET': '',
'SUBSCRIPTIONS_USERNAME': '',
'SUBSCRIPTIONS_PASSWORD': '',
},
True,
('redhat_user', 'redhat_pass'),
@@ -98,8 +98,8 @@ def mock_analytic_post():
{
'REDHAT_USERNAME': None,
'REDHAT_PASSWORD': None,
'SUBSCRIPTIONS_CLIENT_ID': 'subs_user',
'SUBSCRIPTIONS_CLIENT_SECRET': 'subs_pass', # NOSONAR
'SUBSCRIPTIONS_USERNAME': 'subs_user',
'SUBSCRIPTIONS_PASSWORD': 'subs_pass', # NOSONAR
},
True,
('subs_user', 'subs_pass'),
@@ -109,8 +109,8 @@ def mock_analytic_post():
{
'REDHAT_USERNAME': '',
'REDHAT_PASSWORD': '',
'SUBSCRIPTIONS_CLIENT_ID': 'subs_user',
'SUBSCRIPTIONS_CLIENT_SECRET': 'subs_pass', # NOSONAR
'SUBSCRIPTIONS_USERNAME': 'subs_user',
'SUBSCRIPTIONS_PASSWORD': 'subs_pass', # NOSONAR
},
True,
('subs_user', 'subs_pass'),
@@ -120,8 +120,8 @@ def mock_analytic_post():
{
'REDHAT_USERNAME': '',
'REDHAT_PASSWORD': '',
'SUBSCRIPTIONS_CLIENT_ID': '',
'SUBSCRIPTIONS_CLIENT_SECRET': '',
'SUBSCRIPTIONS_USERNAME': '',
'SUBSCRIPTIONS_PASSWORD': '',
},
False,
None, # No request should be made
@@ -131,8 +131,8 @@ def mock_analytic_post():
{
'REDHAT_USERNAME': '',
'REDHAT_PASSWORD': 'redhat_pass', # NOSONAR
'SUBSCRIPTIONS_CLIENT_ID': 'subs_user',
'SUBSCRIPTIONS_CLIENT_SECRET': '',
'SUBSCRIPTIONS_USERNAME': 'subs_user',
'SUBSCRIPTIONS_PASSWORD': '',
},
False,
None, # Invalid, no request should be made
@@ -150,24 +150,3 @@ def test_ship_credential(setting_map, expected_result, expected_auth, temp_analy
assert mock_analytic_post.call_args[1]['auth'] == expected_auth
else:
mock_analytic_post.assert_not_called()
@pytest.mark.django_db
def test_gather_cleanup_on_auth_failure(mock_valid_license, temp_analytic_tar):
"""
gather() must remove the packaged tarball when ship() reports failure.

`package` is patched to return the path of a real temporary file and `ship`
is patched to return False; afterwards that file must no longer exist.

NOTE(review): the name says "auth failure", but the only failure simulated
here is ship() returning False.
"""
settings.INSIGHTS_TRACKING_STATE = True
settings.AUTOMATION_ANALYTICS_URL = 'https://example.com/api'
settings.REDHAT_USERNAME = 'test_user'
settings.REDHAT_PASSWORD = 'test_password'
# delete=False keeps the file on disk after the handle is closed, so there
# is something for gather() to clean up.
with tempfile.NamedTemporaryFile(delete=False, suffix='.tar.gz') as temp_file:
temp_file_path = temp_file.name
try:
with mock.patch('awx.main.analytics.core.ship', return_value=False):
with mock.patch('awx.main.analytics.core.package', return_value=temp_file_path):
gather(module=importlib.import_module(__name__), collection_type='scheduled')
assert not os.path.exists(temp_file_path), "Temp file was not cleaned up after ship failure"
finally:
# Do not leak the temp file if gather() failed to remove it.
if os.path.exists(temp_file_path):
os.remove(temp_file_path)

View File

@@ -30,7 +30,6 @@ EXPECTED_VALUES = {
'awx_license_instance_free': 0,
'awx_pending_jobs_total': 0,
'awx_database_connections_total': 1,
'awx_license_expiry': 0,
}

View File

@@ -97,8 +97,8 @@ class TestAnalyticsGenericView:
'INSIGHTS_TRACKING_STATE': True,
'REDHAT_USERNAME': 'redhat_user',
'REDHAT_PASSWORD': 'redhat_pass', # NOSONAR
'SUBSCRIPTIONS_CLIENT_ID': '',
'SUBSCRIPTIONS_CLIENT_SECRET': '',
'SUBSCRIPTIONS_USERNAME': '',
'SUBSCRIPTIONS_PASSWORD': '',
},
('redhat_user', 'redhat_pass'),
None,
@@ -109,8 +109,8 @@ class TestAnalyticsGenericView:
'INSIGHTS_TRACKING_STATE': True,
'REDHAT_USERNAME': '',
'REDHAT_PASSWORD': '',
'SUBSCRIPTIONS_CLIENT_ID': 'subs_user',
'SUBSCRIPTIONS_CLIENT_SECRET': 'subs_pass', # NOSONAR
'SUBSCRIPTIONS_USERNAME': 'subs_user',
'SUBSCRIPTIONS_PASSWORD': 'subs_pass', # NOSONAR
},
('subs_user', 'subs_pass'),
None,
@@ -121,8 +121,8 @@ class TestAnalyticsGenericView:
'INSIGHTS_TRACKING_STATE': True,
'REDHAT_USERNAME': '',
'REDHAT_PASSWORD': '',
'SUBSCRIPTIONS_CLIENT_ID': '',
'SUBSCRIPTIONS_CLIENT_SECRET': '',
'SUBSCRIPTIONS_USERNAME': '',
'SUBSCRIPTIONS_PASSWORD': '',
},
None,
ERROR_MISSING_USER,
@@ -133,8 +133,8 @@ class TestAnalyticsGenericView:
'INSIGHTS_TRACKING_STATE': True,
'REDHAT_USERNAME': 'redhat_user',
'REDHAT_PASSWORD': 'redhat_pass', # NOSONAR
'SUBSCRIPTIONS_CLIENT_ID': 'subs_user',
'SUBSCRIPTIONS_CLIENT_SECRET': 'subs_pass', # NOSONAR
'SUBSCRIPTIONS_USERNAME': 'subs_user',
'SUBSCRIPTIONS_PASSWORD': 'subs_pass', # NOSONAR
},
('redhat_user', 'redhat_pass'),
None,
@@ -145,8 +145,8 @@ class TestAnalyticsGenericView:
'INSIGHTS_TRACKING_STATE': True,
'REDHAT_USERNAME': '',
'REDHAT_PASSWORD': '',
'SUBSCRIPTIONS_CLIENT_ID': 'subs_user', # NOSONAR
'SUBSCRIPTIONS_CLIENT_SECRET': '',
'SUBSCRIPTIONS_USERNAME': 'subs_user', # NOSONAR
'SUBSCRIPTIONS_PASSWORD': '',
},
None,
ERROR_MISSING_PASSWORD,
@@ -155,36 +155,26 @@ class TestAnalyticsGenericView:
)
@pytest.mark.django_db
def test__send_to_analytics_credentials(self, settings_map, expected_auth, expected_error_keyword):
"""
Test _send_to_analytics with various combinations of credentials.
"""
with override_settings(**settings_map):
request = RequestFactory().post('/some/path')
view = AnalyticsGenericView()
if expected_auth:
with mock.patch('awx.api.views.analytics.OIDCClient') as mock_oidc_client:
# Configure the mock OIDCClient instance and its make_request method
mock_client_instance = mock.Mock()
mock_oidc_client.return_value = mock_client_instance
mock_client_instance.make_request.return_value = mock.Mock(status_code=200)
with mock.patch('requests.request') as mock_request:
mock_request.return_value = mock.Mock(status_code=200)
analytic_url = view._get_analytics_url(request.path)
response = view._send_to_analytics(request, 'POST')
# Assertions
# Assert OIDCClient instantiation
expected_client_id, expected_client_secret = expected_auth
mock_oidc_client.assert_called_once_with(expected_client_id, expected_client_secret)
# Assert make_request call
mock_client_instance.make_request.assert_called_once_with(
mock_request.assert_called_once_with(
'POST',
analytic_url,
headers=mock.ANY,
auth=expected_auth,
verify=mock.ANY,
params=mock.ANY,
headers=mock.ANY,
json=mock.ANY,
params=mock.ANY,
timeout=mock.ANY,
)
assert response.status_code == 200
@@ -196,64 +186,3 @@ class TestAnalyticsGenericView:
# mock_error_response.assert_called_once_with(expected_error_keyword, remote=False)
assert response.status_code == status.HTTP_403_FORBIDDEN
assert response.data['error']['keyword'] == expected_error_keyword
@pytest.mark.django_db
@pytest.mark.parametrize(
"settings_map, expected_auth",
[
# Test case 1: Username and password should be used for basic auth
(
{
'INSIGHTS_TRACKING_STATE': True,
'REDHAT_USERNAME': 'redhat_user',
'REDHAT_PASSWORD': 'redhat_pass', # NOSONAR
'SUBSCRIPTIONS_CLIENT_ID': '',
'SUBSCRIPTIONS_CLIENT_SECRET': '',
},
('redhat_user', 'redhat_pass'),
),
# Test case 2: Only client ID and secret are set, so there is no
# username/password to fall back to; expected_auth is None and
# _base_auth_request must not be called
(
{
'INSIGHTS_TRACKING_STATE': True,
'REDHAT_USERNAME': '',
'REDHAT_PASSWORD': '',
'SUBSCRIPTIONS_CLIENT_ID': 'subs_user',
'SUBSCRIPTIONS_CLIENT_SECRET': 'subs_pass', # NOSONAR
},
None,
),
],
)
def test__send_to_analytics_fallback_to_basic_auth(self, settings_map, expected_auth):
"""
Test _send_to_analytics with basic auth fallback.

The mocked OIDC client's make_request always raises
requests.RequestException. When `expected_auth` is a (username, password)
tuple, _base_auth_request must be called once with those credentials;
when it is None, _base_auth_request must not be called at all.
"""
with override_settings(**settings_map):
request = RequestFactory().post('/some/path')
view = AnalyticsGenericView()
with mock.patch('awx.api.views.analytics.OIDCClient') as mock_oidc_client, mock.patch(
'awx.api.views.analytics.AnalyticsGenericView._base_auth_request'
) as mock_base_auth_request:
# Configure the mock OIDCClient instance and its make_request method
mock_client_instance = mock.Mock()
mock_oidc_client.return_value = mock_client_instance
# Force the OIDC path to fail so that the fallback path is exercised.
mock_client_instance.make_request.side_effect = requests.RequestException("Incorrect credentials")
analytic_url = view._get_analytics_url(request.path)
view._send_to_analytics(request, 'POST')
if expected_auth:
# assert mock_base_auth_request called with expected_auth
mock_base_auth_request.assert_called_once_with(
request,
'POST',
analytic_url,
expected_auth[0],
expected_auth[1],
mock.ANY,
)
else:
# assert mock_base_auth_request not called
mock_base_auth_request.assert_not_called()

View File

@@ -8,7 +8,6 @@ from django.core.exceptions import ValidationError
from awx.api.versioning import reverse
from awx.main.models import InventorySource, Inventory, ActivityStream
from awx.main.utils.inventory_vars import update_group_variables
@pytest.fixture
@@ -691,241 +690,3 @@ class TestConstructedInventory:
assert inv_r.data['url'] != const_r.data['url']
assert inv_r.data['related']['constructed_url'] == url_const
assert const_r.data['related']['constructed_url'] == url_const
@pytest.mark.django_db
class TestInventoryAllVariables:
@staticmethod
def simulate_update_from_source(inv_src, variables_dict, overwrite_vars=True):
"""
Update the inventory of `inv_src` (`inv_src.inventory`) with the
variables `variables_dict`, as if they came from source `inv_src`.

:param inv_src: An inventory source object; its `inventory` is the one
being updated and saved.
:param dict variables_dict: The variables delivered by the source.
:param bool overwrite_vars: Passed through to `update_group_variables`.
Default is `True`.
:return: The merged variables returned by `update_group_variables`,
which are also saved (JSON-encoded) to the inventory.
"""
# Perform an update from source the same way it is done in
# `inventory_import.Command._update_inventory`.
new_vars = update_group_variables(
group_id=None, # `None` denotes the 'all' group (which doesn't have a pk).
newvars=variables_dict,
dbvars=inv_src.inventory.variables_dict,
invsrc_id=inv_src.id,
inventory_id=inv_src.inventory.id,
overwrite_vars=overwrite_vars,
)
inv_src.inventory.variables = json.dumps(new_vars)
# Persist only the `variables` column.
inv_src.inventory.save(update_fields=["variables"])
return new_vars
def update_and_verify(self, inv_src, new_vars, expect=None, overwrite_vars=True, teststep=None):
"""
Helper: Update from source and verify the new inventory variables.
:param inv_src: An inventory source object with its inventory property
set to the inventory fixture of the caller.
:param dict new_vars: The variables of the inventory source `inv_src`.
:param dict expect: (optional) The expected variables state of the
inventory after the update. If not set or None, expect `new_vars`.
:param bool overwrite_vars: The status of the inventory source option
'overwrite variables'. Default is `True`.
:param teststep: (optional) Label for the current test step. If not
None, it is included in the assertion message.
:raise AssertionError: If the inventory does not contain the expected
variables after the update.
"""
self.simulate_update_from_source(inv_src, new_vars, overwrite_vars=overwrite_vars)
# Both branches make the same comparison; they differ only in the failure message.
if teststep is not None:
assert inv_src.inventory.variables_dict == (expect if expect is not None else new_vars), f"Test step {teststep}"
else:
assert inv_src.inventory.variables_dict == (expect if expect is not None else new_vars)
def test_set_variables_through_inventory_details_update(self, inventory, patch, admin_user):
"""
Set an inventory variable by changing the inventory details, simulating
a user edit.

The variables are sent as the YAML string 'a: x' and must be read back
from the database as {"a": "x"}.
"""
# a: x
patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'a: x'}, user=admin_user, expect=200)
# Reload so the assertion checks the stored value, not the in-memory fixture.
inventory.refresh_from_db()
assert inventory.variables_dict == {"a": "x"}
def test_variables_set_by_user_persist_update_from_src(self, inventory, inventory_source, patch, admin_user):
"""
Verify the special behavior that a variable which originates from a user
edit (instead of a source update), is not removed from the inventory
when a source update with overwrite_vars=True does not contain that
variable. This behavior is considered special because a variable which
originates from a source would actually be deleted.
In addition, verify that an existing variable which was set by a user
edit can be overwritten by a source update.
"""
# Set two variables via user edit.
patch(
url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}),
data={'variables': '{"a": "a_from_user", "b": "b_from_user"}'},
user=admin_user,
expect=200,
)
inventory.refresh_from_db()
assert inventory.variables_dict == {'a': 'a_from_user', 'b': 'b_from_user'}
# Update from a source which contains only one of the two variables from
# the previous update.
self.simulate_update_from_source(inventory_source, {'a': 'a_from_source'})
# Verify inventory variables.
# NOTE(review): there is no refresh_from_db() before this assertion, so it
# appears to rely on `inventory_source.inventory` being the same object as
# the `inventory` fixture -- confirm against the fixtures.
assert inventory.variables_dict == {'a': 'a_from_source', 'b': 'b_from_user'}
def test_variables_set_through_src_get_removed_on_update_from_same_src(self, inventory, inventory_source, patch, admin_user):
"""
Verify that a variable which originates from a source update, is removed
from the inventory when a later update from the same source with
overwrite_vars=True does not contain that variable.
A variable which is still delivered by the source must be kept.

No user edit is made in this test; both updates come from the same
inventory source.
"""
# Set two variables via update from source.
self.simulate_update_from_source(inventory_source, {'a': 'a_from_source', 'b': 'b_from_source'})
# Verify inventory variables.
assert inventory.variables_dict == {'a': 'a_from_source', 'b': 'b_from_source'}
# Update from the same source which now contains only one of the two
# variables from the previous update.
self.simulate_update_from_source(inventory_source, {'b': 'b_from_source'})
# Verify the variable has been deleted from the inventory.
assert inventory.variables_dict == {'b': 'b_from_source'}
def test_overwrite_variables_through_inventory_details_update(self, inventory, patch, admin_user):
"""
Set and update the inventory variables multiple times by changing the
inventory details via api, simulating user edits.
Any variables update by means of an inventory details update shall
overwrite all existing inventory variables.
"""
# a: x
patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'a: x'}, user=admin_user, expect=200)
inventory.refresh_from_db()
assert inventory.variables_dict == {"a": "x"}
# a: x2 (same key, new value)
patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'a: x2'}, user=admin_user, expect=200)
inventory.refresh_from_db()
assert inventory.variables_dict == {"a": "x2"}
# b: y (different key; the previous key `a` must be gone afterwards)
patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'b: y'}, user=admin_user, expect=200)
inventory.refresh_from_db()
assert inventory.variables_dict == {"b": "y"}
def test_inventory_group_variables_internal_data(self, inventory, patch, admin_user):
"""
Basic verification of how variable updates are stored internally.
After each user edit through the inventory detail endpoint, the first
`inventory_group_variables` record is expected to map every variable name
to a list of `[origin, value]` pairs.
.. Warning::
This test verifies a specific implementation of the inventory
variables update business logic. It may deliver false negatives if
the implementation changes.
"""
# a: x
patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'a: x'}, user=admin_user, expect=200)
igv = inventory.inventory_group_variables.first()
# presumably -1 is the origin marker for a user edit (source updates are
# recorded with the inventory source id elsewhere in this class) -- TODO confirm
assert igv.variables == {'a': [[-1, 'x']]}
# b: y
patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'b: y'}, user=admin_user, expect=200)
igv = inventory.inventory_group_variables.first()
# The previous variable `a` is gone: the edit replaced the stored data.
assert igv.variables == {'b': [[-1, 'y']]}
def test_update_then_user_change(self, inventory, patch, admin_user, inventory_source):
    """
    1. Update inventory vars by means of an inventory source update.
    2. Update inventory vars by editing the inventory details (aka a 'user
       update'), thereby changing variables values and deleting variables
       from the inventory.

    .. Warning::

        This test partly relies on a specific implementation of the
        inventory variables update business logic. It may deliver false
        negatives if the implementation changes.
    """
    assert inventory_source.inventory_id == inventory.pk  # sanity

    # ---- Test step 1: Set variables by updating from an inventory source.
    self.simulate_update_from_source(inventory_source, {'foo': 'foo_from_source', 'bar': 'bar_from_source'})
    # Verify inventory variables.
    assert inventory.variables_dict == {'foo': 'foo_from_source', 'bar': 'bar_from_source'}
    # Verify internal storage of variables data. Note that this is
    # implementation specific
    assert inventory.inventory_group_variables.count() == 1
    igv = inventory.inventory_group_variables.first()
    assert igv.variables == {'foo': [[inventory_source.id, 'foo_from_source']], 'bar': [[inventory_source.id, 'bar_from_source']]}

    # ---- Test step 2: Change the variables by editing the inventory details.
    patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'foo: foo_from_user'}, user=admin_user, expect=200)
    inventory.refresh_from_db()
    # Verify that variable `foo` contains the new value, and that variable
    # `bar` has been deleted from the inventory.
    assert inventory.variables_dict == {"foo": "foo_from_user"}
    # Verify internal storage of variables data. Note that this is
    # implementation specific
    # This count check used to be a bare comparison whose result was
    # discarded; it is now asserted like the equivalent check in step 1.
    assert inventory.inventory_group_variables.count() == 1
    igv = inventory.inventory_group_variables.first()
    assert igv.variables == {'foo': [[-1, 'foo_from_user']]}
def test_monotonic_deletions(self, inventory, patch, admin_user):
"""
Verify the variables history logic for monotonic deletions.
Monotonic in this context means that the variables are deleted in the
reverse order of their creation.
1. Set inventory variable x: 0, expect INV={x: 0}
(The following steps use overwrite_variables=False)
2. Update from source A={x: 1}, expect INV={x: 1}
3. Update from source B={x: 2}, expect INV={x: 2}
4. Update from source B={}, expect INV={x: 1}
5. Update from source A={}, expect INV={x: 0}
"""
inv_src_a = InventorySource.objects.create(name="inv-src-A", inventory=inventory, source="ec2")
inv_src_b = InventorySource.objects.create(name="inv-src-B", inventory=inventory, source="ec2")
# Test step 1: Set inventory variable x: 0 by a user edit
patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'x: 0'}, user=admin_user, expect=200)
inventory.refresh_from_db()
assert inventory.variables_dict == {"x": 0}
# Test step 2: Source A overwrites value of var x
self.update_and_verify(inv_src_a, {"x": 1}, teststep=2)
# Test step 3: Source B overwrites value of var x
self.update_and_verify(inv_src_b, {"x": 2}, teststep=3)
# Test step 4: Value of var x from source A reappears
self.update_and_verify(inv_src_b, {}, expect={"x": 1}, teststep=4)
# Test step 5: Value of var x from initial user edit reappears
self.update_and_verify(inv_src_a, {}, expect={"x": 0}, teststep=5)
def test_interleaved_deletions(self, inventory, patch, admin_user, inventory_source):
"""
Verify the variables history logic for interleaved deletions.
Interleaved in this context means that the variables are deleted in a
different order than the sequence of their creation.
1. Set inventory variable x: 0, expect INV={x: 0}
2. Update from source A={x: 1}, expect INV={x: 1}
3. Update from source B={x: 2}, expect INV={x: 2}
4. Update from source C={x: 3}, expect INV={x: 3}
5. Update from source B={}, expect INV={x: 3}
6. Update from source C={}, expect INV={x: 1}
NOTE(review): the `inventory_source` fixture is requested but not
referenced in the body; the test creates its own sources A, B and C.
"""
inv_src_a = InventorySource.objects.create(name="inv-src-A", inventory=inventory, source="ec2")
inv_src_b = InventorySource.objects.create(name="inv-src-B", inventory=inventory, source="ec2")
inv_src_c = InventorySource.objects.create(name="inv-src-C", inventory=inventory, source="ec2")
# Test step 1. Set inventory variable x: 0
patch(url=reverse('api:inventory_detail', kwargs={'pk': inventory.pk}), data={'variables': 'x: 0'}, user=admin_user, expect=200)
inventory.refresh_from_db()
assert inventory.variables_dict == {"x": 0}
# Test step 2: Source A overwrites value of var x
self.update_and_verify(inv_src_a, {"x": 1}, teststep=2)
# Test step 3: Source B overwrites value of var x
self.update_and_verify(inv_src_b, {"x": 2}, teststep=3)
# Test step 4: Source C overwrites value of var x
self.update_and_verify(inv_src_c, {"x": 3}, teststep=4)
# Test step 5: Value of var x from source C remains unchanged
self.update_and_verify(inv_src_b, {}, expect={"x": 3}, teststep=5)
# Test step 6: Value of var x from source A reappears, because the
# latest update from source B did not contain var x.
self.update_and_verify(inv_src_c, {}, expect={"x": 1}, teststep=6)

View File

@@ -34,18 +34,40 @@ def test_wrapup_does_send_notifications(mocker):
mock.assert_called_once_with('succeeded')
class FakeRedis:
    """Inert stand-in used to patch ``redis.Redis`` in tests.

    Every method is a no-op returning a fixed value, so code under test can
    talk to "redis" without a server. All methods accept arbitrary positional
    and keyword arguments so that calls shaped like the real client's
    (e.g. ``set(name, value)``, ``get(name)``) do not raise ``TypeError``.
    """

    def keys(self, *args, **kwargs):
        """Return an empty key list regardless of the pattern."""
        return []

    def set(self, *args, **kwargs):
        """Discard the value; nothing is stored."""
        pass

    def get(self, *args, **kwargs):
        """Always report a missing key."""
        return None

    @classmethod
    def from_url(cls, *args, **kwargs):
        """Mirror the alternate constructor; the URL is ignored."""
        return cls()

    def pipeline(self, *args, **kwargs):
        """Return this object so pipelined calls hit the same no-op methods."""
        return self
class TestCallbackBrokerWorker(TransactionTestCase):
@pytest.fixture(autouse=True)
def turn_off_websockets_and_redis(self, fake_redis):
def turn_off_websockets(self):
with mock.patch('awx.main.dispatch.worker.callback.emit_event_detail', lambda *a, **kw: None):
yield
def get_worker(self):
    """Construct a CallbackBrokerWorker while ``redis.Redis`` is patched with FakeRedis."""
    redis_patch = mock.patch('redis.Redis', new=FakeRedis)  # turn off redis stuff
    with redis_patch:
        worker = CallbackBrokerWorker()
    return worker
def event_create_kwargs(self):
    """Create a file-based inventory update and return the kwargs that tie an event to it."""
    source = InventorySource.objects.create(source='file')
    update = InventoryUpdate.objects.create(source='file', inventory_source=source)
    return {'inventory_update': update, 'created': update.created}
def test_flush_with_valid_event(self):
worker = CallbackBrokerWorker()
worker = self.get_worker()
events = [InventoryUpdateEvent(uuid=str(uuid4()), **self.event_create_kwargs())]
worker.buff = {InventoryUpdateEvent: events}
worker.flush()
@@ -53,7 +75,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
assert InventoryUpdateEvent.objects.filter(uuid=events[0].uuid).count() == 1
def test_flush_with_invalid_event(self):
worker = CallbackBrokerWorker()
worker = self.get_worker()
kwargs = self.event_create_kwargs()
events = [
InventoryUpdateEvent(uuid=str(uuid4()), stdout='good1', **kwargs),
@@ -68,7 +90,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
assert worker.buff == {InventoryUpdateEvent: [events[1]]}
def test_duplicate_key_not_saved_twice(self):
worker = CallbackBrokerWorker()
worker = self.get_worker()
events = [InventoryUpdateEvent(uuid=str(uuid4()), **self.event_create_kwargs())]
worker.buff = {InventoryUpdateEvent: events.copy()}
worker.flush()
@@ -82,7 +104,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
assert worker.buff.get(InventoryUpdateEvent, []) == []
def test_give_up_on_bad_event(self):
worker = CallbackBrokerWorker()
worker = self.get_worker()
events = [InventoryUpdateEvent(uuid=str(uuid4()), counter=-2, **self.event_create_kwargs())]
worker.buff = {InventoryUpdateEvent: events.copy()}
@@ -95,7 +117,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
assert InventoryUpdateEvent.objects.filter(uuid=events[0].uuid).count() == 0 # sanity
def test_flush_with_empty_buffer(self):
worker = CallbackBrokerWorker()
worker = self.get_worker()
worker.buff = {InventoryUpdateEvent: []}
with mock.patch.object(InventoryUpdateEvent.objects, 'bulk_create') as flush_mock:
worker.flush()
@@ -105,7 +127,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
# In postgres, text fields reject NUL character, 0x00
# tests use sqlite3 which will not raise an error
# but we can still test that it is sanitized before saving
worker = CallbackBrokerWorker()
worker = self.get_worker()
kwargs = self.event_create_kwargs()
events = [InventoryUpdateEvent(uuid=str(uuid4()), stdout="\x00", **kwargs)]
assert "\x00" in events[0].stdout # sanity

View File

@@ -63,33 +63,6 @@ def swagger_autogen(requests=__SWAGGER_REQUESTS__):
return requests
class FakeRedis:
    """Inert stand-in used to patch ``redis.Redis`` in tests.

    Every method is a no-op returning a fixed value, so code under test can
    talk to "redis" without a server. All methods accept arbitrary positional
    and keyword arguments so that calls shaped like the real client's
    (e.g. ``set(name, value)``, ``get(name)``) do not raise ``TypeError``.
    """

    def keys(self, *args, **kwargs):
        """Return an empty key list regardless of the pattern."""
        return []

    def set(self, *args, **kwargs):
        """Discard the value; nothing is stored."""
        pass

    def get(self, *args, **kwargs):
        """Always report a missing key."""
        return None

    @classmethod
    def from_url(cls, *args, **kwargs):
        """Mirror the alternate constructor; the URL is ignored."""
        return cls()

    def pipeline(self, *args, **kwargs):
        """Return this object so pipelined calls hit the same no-op methods."""
        return self

    def ping(self, *args, **kwargs):
        """Pretend the connection check succeeded; returns None."""
        return
@pytest.fixture
def fake_redis():
    """Replace ``redis.Redis`` with FakeRedis for the duration of the requesting test."""
    redis_patch = mock.patch('redis.Redis', new=FakeRedis)  # turn off redis stuff
    with redis_patch:
        yield
@pytest.fixture
def user():
def u(name, is_superuser=False):

View File

@@ -1,56 +0,0 @@
import pytest
from awx.main.migrations._db_constraints import _rename_duplicates
from awx.main.models import JobTemplate
@pytest.mark.django_db
def test_rename_job_template_duplicates(organization, project):
    """Five job templates forced onto one name: the first keeps it, the rest get a ``_dup<N>`` suffix."""
    ids = []  # primary keys, in order of creation
    for n in range(5):
        created = JobTemplate.objects.create(name=f'jt-{n}', organization=organization, project=project)
        ids.append(created.id)

    # Hack to first allow duplicate names of JT to test migration
    JobTemplate.objects.filter(id__in=ids).update(org_unique=False)

    # Set all JTs to the same name
    JobTemplate.objects.filter(id__in=ids).update(name='same_name_for_test')

    _rename_duplicates(JobTemplate)

    first_id, *duplicate_ids = ids
    assert JobTemplate.objects.get(id=first_id).name == 'same_name_for_test'

    # Name should be set based on creation order
    for position, pk in enumerate(duplicate_ids, start=1):
        renamed = JobTemplate.objects.get(id=pk)
        assert renamed.name == f'same_name_for_test_dup{position}'
@pytest.mark.django_db
def test_rename_job_template_name_too_long(organization, project):
    """Duplicates of a 512-character name are renamed with a ``dup<N>`` ending and stay within 512 characters."""
    ids = []  # primary keys, in order of creation
    for n in range(3):
        created = JobTemplate.objects.create(name=f'jt-{n}', organization=organization, project=project)
        ids.append(created.id)

    JobTemplate.objects.filter(id__in=ids).update(org_unique=False)

    chars = 512
    long_name = 'A' * chars
    # Set all JTs to the same reaaaaaaly long name
    JobTemplate.objects.filter(id__in=ids).update(name=long_name)

    _rename_duplicates(JobTemplate)

    first_id, *duplicate_ids = ids
    assert JobTemplate.objects.get(id=first_id).name == long_name

    for position, pk in enumerate(duplicate_ids, start=1):
        renamed = JobTemplate.objects.get(id=pk)
        assert renamed.name.endswith(f'dup{position}')
        assert len(renamed.name) <= 512

View File

@@ -3,10 +3,6 @@ import pytest
# AWX
from awx.main.ha import is_ha_environment
from awx.main.models.ha import Instance
from awx.main.dispatch.pool import get_auto_max_workers
# Django
from django.test.utils import override_settings
@pytest.mark.django_db
@@ -21,25 +17,3 @@ def test_db_localhost():
Instance.objects.create(hostname='foo', node_type='hybrid')
Instance.objects.create(hostname='bar', node_type='execution')
assert is_ha_environment() is False
@pytest.mark.django_db
@pytest.mark.parametrize(
'settings',
[
dict(SYSTEM_TASK_ABS_MEM='16Gi', SYSTEM_TASK_ABS_CPU='24', SYSTEM_TASK_FORKS_MEM=400, SYSTEM_TASK_FORKS_CPU=4),
dict(SYSTEM_TASK_ABS_MEM='124Gi', SYSTEM_TASK_ABS_CPU='2', SYSTEM_TASK_FORKS_MEM=None, SYSTEM_TASK_FORKS_CPU=None),
],
ids=['cpu_dominated', 'memory_dominated'],
)
def test_dispatcher_max_workers_reserve(settings, fake_redis):
"""This tests that the dispatcher max_workers matches instance capacity
Assumes capacity_adjustment is 1,
plus reserve worker count
Each parametrized `settings` dict is applied through override_settings
before a hybrid instance is created and health-checked.
"""
with override_settings(**settings):
i = Instance.objects.create(hostname='test-1', node_type='hybrid')
# Run the local health check on the freshly created instance before its
# capacity is read below.
i.local_health_check()
# NOTE(review): 7 is presumably the reserve worker count mentioned in the
# docstring -- confirm against get_auto_max_workers. The tuple is only the
# assertion message, shown on failure.
assert get_auto_max_workers() == i.capacity + 7, (i.cpu, i.memory, i.cpu_capacity, i.mem_capacity)

View File

@@ -393,7 +393,7 @@ def test_dependency_isolation(organization):
this should keep dependencies isolated"""
with mock.patch('awx.main.models.unified_jobs.UnifiedJobTemplate.update'):
updating_projects = [
Project.objects.create(name=f'iso-proj{i}', organization=organization, scm_url='https://foo.invalid', scm_type='git', scm_update_on_launch=True)
Project.objects.create(name='iso-proj', organization=organization, scm_url='https://foo.invalid', scm_type='git', scm_update_on_launch=True)
for i in range(2)
]

View File

@@ -15,17 +15,3 @@ def test_does_not_run_reaped_job(mocker, mock_me):
job.refresh_from_db()
assert job.status == 'failed'
mock_run.assert_not_called()
@pytest.mark.django_db
def test_cancel_flag_on_start(jt_linked, caplog):
    """A waiting job that already carries the cancel flag ends up 'canceled' once RunJob runs it."""
    flagged = jt_linked.create_unified_job()
    flagged.status = 'waiting'
    flagged.cancel_flag = True
    flagged.save()

    runner = RunJob()
    runner.run(flagged.id)

    reloaded = Job.objects.get(id=flagged.id)
    assert reloaded.status == 'canceled'

View File

@@ -43,7 +43,7 @@ def test_job_template_copy(
c.save()
assert get(reverse('api:job_template_copy', kwargs={'pk': job_template_with_survey_passwords.pk}), alice, expect=200).data['can_copy'] is True
jt_copy_pk_alice = post(
reverse('api:job_template_copy', kwargs={'pk': job_template_with_survey_passwords.pk}), {'name': 'new jt name alice'}, alice, expect=201
reverse('api:job_template_copy', kwargs={'pk': job_template_with_survey_passwords.pk}), {'name': 'new jt name'}, alice, expect=201
).data['id']
jt_copy_admin = type(job_template_with_survey_passwords).objects.get(pk=jt_copy_pk)
@@ -53,7 +53,7 @@ def test_job_template_copy(
assert jt_copy_alice.created_by == alice
for jt_copy in (jt_copy_admin, jt_copy_alice):
assert jt_copy.name.startswith('new jt name')
assert jt_copy.name == 'new jt name'
assert jt_copy.project == project
assert jt_copy.inventory == inventory
assert jt_copy.playbook == job_template_with_survey_passwords.playbook

View File

@@ -5,11 +5,8 @@ import signal
import time
import yaml
from unittest import mock
from copy import deepcopy
from django.utils.timezone import now as tz_now
from django.conf import settings
from django.test.utils import override_settings
import pytest
from awx.main.models import Job, WorkflowJob, Instance
@@ -303,13 +300,6 @@ class TestTaskDispatcher:
class TestTaskPublisher:
@pytest.fixture(autouse=True)
def _disable_dispatcherd(self):
    """Run each test with the first FEATURE_DISPATCHERD_ENABLED flag entry set to False."""
    # Work on a deep copy so the real settings.FLAGS structure is never mutated.
    patched_flags = deepcopy(settings.FLAGS)
    dispatcherd_entry = patched_flags['FEATURE_DISPATCHERD_ENABLED'][0]
    dispatcherd_entry['value'] = False
    with override_settings(FLAGS=patched_flags):
        yield
def test_function_callable(self):
assert add(2, 2) == 4

View File

@@ -209,7 +209,7 @@ def test_inventory_update_injected_content(product_name, this_kind, inventory, f
source_vars=src_vars,
)
inventory_source.credentials.add(fake_credential_factory(this_kind))
inventory_update = inventory_source.create_unified_job(_eager_fields={'status': 'waiting'})
inventory_update = inventory_source.create_unified_job()
task = RunInventoryUpdate()
def substitute_run(awx_receptor_job):

View File

@@ -19,7 +19,7 @@ from awx.main.models import (
ExecutionEnvironment,
)
from awx.main.tasks.system import cluster_node_heartbeat
from awx.main.utils.db import bulk_update_sorted_by_id
from awx.main.tasks.facts import update_hosts
from django.db import OperationalError
from django.test.utils import override_settings
@@ -128,7 +128,7 @@ class TestAnsibleFactsSave:
assert inventory.hosts.count() == 3
Host.objects.get(pk=last_pk).delete()
assert inventory.hosts.count() == 2
bulk_update_sorted_by_id(Host, hosts, fields=['ansible_facts'])
update_hosts(hosts)
assert inventory.hosts.count() == 2
for host in inventory.hosts.all():
host.refresh_from_db()
@@ -141,7 +141,7 @@ class TestAnsibleFactsSave:
db_mock = mocker.patch('awx.main.tasks.facts.Host.objects.bulk_update')
db_mock.side_effect = OperationalError('deadlock detected')
with pytest.raises(OperationalError):
bulk_update_sorted_by_id(Host, hosts, fields=['ansible_facts'])
update_hosts(hosts)
def fake_bulk_update(self, host_list):
if self.current_call > 2:
@@ -149,28 +149,16 @@ class TestAnsibleFactsSave:
self.current_call += 1
raise OperationalError('deadlock detected')
@pytest.mark.django_db
def test_update_hosts_resolved_deadlock(inventory, mocker):
hosts = [Host.objects.create(inventory=inventory, name=f'foo{i}') for i in range(3)]
# Set ansible_facts for each host
for host in hosts:
host.ansible_facts = {'foo': 'bar'}
bulk_update_sorted_by_id(Host, hosts, fields=['ansible_facts'])
# Save changes and refresh from DB to ensure the updated facts are saved
for host in hosts:
host.save() # Ensure changes are persisted in the DB
host.refresh_from_db() # Refresh from DB to get latest data
# Assert that the ansible_facts were updated correctly
for host in inventory.hosts.all():
assert host.ansible_facts == {'foo': 'bar'}
bulk_update_sorted_by_id(Host, hosts, fields=['ansible_facts'])
def test_update_hosts_resolved_deadlock(self, inventory, mocker):
hosts = [Host.objects.create(inventory=inventory, name=f'foo{i}') for i in range(3)]
for host in hosts:
host.ansible_facts = {'foo': 'bar'}
self.current_call = 0
mocker.patch('awx.main.tasks.facts.raw_update_hosts', new=self.fake_bulk_update)
update_hosts(hosts)
for host in inventory.hosts.all():
host.refresh_from_db()
assert host.ansible_facts == {'foo': 'bar'}
@pytest.mark.django_db

View File

@@ -47,7 +47,6 @@ def index_licenses(path):
def parse_requirement(reqt):
parsed_requirement = parse_req_from_line(reqt.requirement, None)
assert parsed_requirement.requirement, reqt.__dict__
name = parsed_requirement.requirement.name
version = str(parsed_requirement.requirement.specifier)
if version.startswith('=='):

View File

@@ -106,37 +106,3 @@ class TestMigrationSmoke:
)
DABPermission = new_state.apps.get_model('dab_rbac', 'DABPermission')
assert not DABPermission.objects.filter(codename='view_executionenvironment').exists()
# Test create a Project with a duplicate name
Organization = new_state.apps.get_model('main', 'Organization')
Project = new_state.apps.get_model('main', 'Project')
org = Organization.objects.create(name='duplicate-obj-organization', created=now(), modified=now())
proj_ids = []
for i in range(3):
proj = Project.objects.create(name='duplicate-project-name', organization=org, created=now(), modified=now())
proj_ids.append(proj.id)
# The uniqueness rules will not apply to InventorySource
Inventory = new_state.apps.get_model('main', 'Inventory')
InventorySource = new_state.apps.get_model('main', 'InventorySource')
inv = Inventory.objects.create(name='migration-test-inv', organization=org, created=now(), modified=now())
InventorySource.objects.create(name='migration-test-src', source='file', inventory=inv, organization=org, created=now(), modified=now())
new_state = migrator.apply_tested_migration(
('main', '0200_template_name_constraint'),
)
for i, proj_id in enumerate(proj_ids):
proj = Project.objects.get(id=proj_id)
if i == 0:
assert proj.name == 'duplicate-project-name'
else:
assert proj.name != 'duplicate-project-name'
assert proj.name.startswith('duplicate-project-name')
# The inventory source had this field set to avoid the constrains
InventorySource = new_state.apps.get_model('main', 'InventorySource')
inv_src = InventorySource.objects.get(name='migration-test-src')
assert inv_src.org_unique is False
Project = new_state.apps.get_model('main', 'Project')
for proj in Project.objects.all():
assert proj.org_unique is True

View File

@@ -1,631 +0,0 @@
import json
import os
from unittest import mock
import pytest
import requests.exceptions
from django.test import override_settings
from awx.main.models import (
Job,
Inventory,
Project,
Organization,
JobTemplate,
Credential,
CredentialType,
User,
Team,
Label,
WorkflowJob,
WorkflowJobNode,
InventorySource,
)
from awx.main.exceptions import PolicyEvaluationError
from awx.main.tasks import policy
from awx.main.tasks.policy import JobSerializer, OPA_AUTH_TYPES
def _parse_exception_message(exception: PolicyEvaluationError):
    """Return the details that follow the fixed policy-violation banner in ``str(exception.value)``.

    NOTE(review): callers in this file pass the object bound by
    ``pytest.raises(...) as pe`` (hence ``.value``), not the exception itself.
    """
    violation_message = "This job cannot be executed due to a policy violation or error. See the following details:"
    rendered = str(exception.value)
    assert violation_message in rendered
    details_text = rendered.split(violation_message)[1].strip()
    # NOTE(review): eval() executes whatever text follows the banner. The input
    # here is produced by the code under test, but a literal-only parser would
    # be safer if the rendered format allows it.
    return eval(details_text)
@pytest.fixture(autouse=True)
def setup_opa_settings():
    """Give every test in this module an OPA_HOST setting of 'opa.example.com'."""
    opa_overrides = {'OPA_HOST': 'opa.example.com'}
    with override_settings(**opa_overrides):
        yield
@pytest.fixture
def opa_client():
    """Patch ``awx.main.tasks.policy.OpaClient`` with a mock class and yield the instance it returns."""
    client_cls = mock.MagicMock(name='OpaClient')
    client = client_cls.return_value
    # Entering the mocked instance as a context manager hands back the same instance.
    client.__enter__.return_value = client
    with mock.patch('awx.main.tasks.policy.OpaClient', client_cls):
        yield client
@pytest.fixture
def job():
    """Create a job whose inventory, organization and job template each define an ``opa_query_path``."""
    proj = Project.objects.create(name='proj1', scm_type='git', scm_branch='main', scm_url='https://git.example.com/proj1')
    inv = Inventory.objects.create(name='inv1', opa_query_path="inventory/response")
    owning_org = Organization.objects.create(name="org1", opa_query_path="organization/response")
    template = JobTemplate.objects.create(name="jt1", opa_query_path="job_template/response")
    return Job.objects.create(name='job1', extra_vars="{}", inventory=inv, project=proj, organization=owning_org, job_template=template)
@pytest.mark.django_db
def test_job_serializer():
"""Build a job with related objects and compare ``JobSerializer(...).data`` against the full expected dict.
The job gets a creating user (who is added to a team's admin role), an
organization, project, inventory with one inventory source, an SSH
credential, a label and a workflow job node.
"""
user: User = User.objects.create(username='user1')
org: Organization = Organization.objects.create(name='org1')
team: Team = Team.objects.create(name='team1', organization=org)
team.admin_role.members.add(user)
project: Project = Project.objects.create(name='proj1', scm_type='git', scm_branch='main', scm_url='https://git.example.com/proj1')
inventory: Inventory = Inventory.objects.create(name='inv1', description='Demo inventory')
inventory_source: InventorySource = InventorySource.objects.create(name='inv-src1', source='file', inventory=inventory)
extra_vars = {"FOO": "value1", "BAR": "value2"}
# The managed credential types must exist before the 'ssh' type is looked up.
CredentialType.setup_tower_managed_defaults()
cred_type_ssh: CredentialType = CredentialType.objects.get(kind='ssh')
cred: Credential = Credential.objects.create(name="cred1", description='Demo credential', credential_type=cred_type_ssh, organization=org)
label: Label = Label.objects.create(name='label1', organization=org)
# extra_vars is stored as a JSON string; the expected output below contains the dict.
job: Job = Job.objects.create(
name='job1', extra_vars=json.dumps(extra_vars), inventory=inventory, project=project, organization=org, created_by=user, launch_type='workflow'
)
# job.unified_job_node.workflow_job = workflow_job
job.credentials.add(cred)
job.labels.add(label)
# Link the job to a workflow job through a workflow job node.
workflow_job: WorkflowJob = WorkflowJob.objects.create(name='wf-job1')
WorkflowJobNode.objects.create(job=job, workflow_job=workflow_job)
serializer = JobSerializer(instance=job)
# Exact-match comparison of the whole serialized payload.
assert serializer.data == {
'id': job.id,
'name': 'job1',
'created': job.created.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
'created_by': {
'id': user.id,
'username': 'user1',
'is_superuser': False,
'teams': [
{'id': team.id, 'name': 'team1'},
],
},
'credentials': [
{
'id': cred.id,
'name': 'cred1',
'description': 'Demo credential',
'organization': {
'id': org.id,
'name': 'org1',
},
'credential_type': cred_type_ssh.id,
'kind': 'ssh',
'managed': False,
'kubernetes': False,
'cloud': False,
},
],
'execution_environment': None,
'extra_vars': extra_vars,
'forks': 0,
'hosts_count': 0,
'instance_group': None,
'inventory': {
'id': inventory.id,
'name': 'inv1',
'description': 'Demo inventory',
'kind': '',
'total_hosts': 0,
'total_groups': 0,
'has_inventory_sources': False,
'total_inventory_sources': 0,
'has_active_failures': False,
'hosts_with_active_failures': 0,
'inventory_sources': [
{
'id': inventory_source.id,
'name': 'inv-src1',
'source': 'file',
'status': 'never updated',
}
],
},
'job_template': None,
'job_type': 'run',
'job_type_name': 'job',
'labels': [
{
'id': label.id,
'name': 'label1',
'organization': {
'id': org.id,
'name': 'org1',
},
},
],
'launch_type': 'workflow',
'limit': '',
'launched_by': {},
'organization': {
'id': org.id,
'name': 'org1',
},
'playbook': '',
'project': {
'id': project.id,
'name': 'proj1',
'status': 'pending',
'scm_type': 'git',
'scm_url': 'https://git.example.com/proj1',
'scm_branch': 'main',
'scm_refspec': '',
'scm_clean': False,
'scm_track_submodules': False,
'scm_delete_on_update': False,
},
'scm_branch': '',
'scm_revision': '',
'workflow_job': {
'id': workflow_job.id,
'name': 'wf-job1',
},
'workflow_job_template': None,
}
@pytest.mark.django_db
def test_evaluate_policy_missing_opa_query_path_field(opa_client):
    """With no ``opa_query_path`` set on any related object, evaluation passes and OPA is never queried."""
    proj = Project.objects.create(name='proj1', scm_type='git', scm_branch='main', scm_url='https://git.example.com/proj1')
    inv = Inventory.objects.create(name='inv1')
    owning_org = Organization.objects.create(name="org1")
    template = JobTemplate.objects.create(name="jt1")
    plain_job = Job.objects.create(name='job1', extra_vars="{}", inventory=inv, project=proj, organization=owning_org, job_template=template)

    opa_client.query_rule.return_value = {"result": {"allowed": True, "violations": []}}

    try:
        policy.evaluate_policy(plain_job)
    except PolicyEvaluationError as exc:
        pytest.fail(f"Must not raise PolicyEvaluationError: {exc}")

    assert opa_client.query_rule.call_count == 0
@pytest.mark.django_db
def test_evaluate_policy(opa_client, job):
    """An allowed job triggers exactly three OPA queries: organization, inventory, job template, in that order."""
    opa_client.query_rule.return_value = {"result": {"allowed": True, "violations": []}}

    try:
        policy.evaluate_policy(job)
    except PolicyEvaluationError as exc:
        pytest.fail(f"Must not raise PolicyEvaluationError: {exc}")

    query_paths = ('organization/response', 'inventory/response', 'job_template/response')
    expected_calls = [mock.call(input_data=mock.ANY, package_path=path) for path in query_paths]
    opa_client.query_rule.assert_has_calls(expected_calls, any_order=False)
    assert opa_client.query_rule.call_count == 3
@pytest.mark.django_db
def test_evaluate_policy_allowed(opa_client, job):
    """An 'allowed' OPA result with no violations must not raise; three queries are made."""
    opa_client.query_rule.return_value = {"result": {"allowed": True, "violations": []}}

    try:
        policy.evaluate_policy(job)
    except PolicyEvaluationError as exc:
        pytest.fail(f"Must not raise PolicyEvaluationError: {exc}")

    assert opa_client.query_rule.call_count == 3
@pytest.mark.django_db
def test_evaluate_policy_not_allowed(opa_client, job):
    """A 'not allowed' OPA result raises PolicyEvaluationError listing the violation for every queried scope."""
    opa_client.query_rule.return_value = {"result": {"allowed": False, "violations": ["Access not allowed."]}}

    with pytest.raises(PolicyEvaluationError) as pe:
        policy.evaluate_policy(job)

    rendered = str(pe.value)
    # NOTE(review): if the details are rendered with quoted keys (as the
    # eval-based parser suggests), the text would contain "Errors':" rather than
    # "Errors:", which would make this substring check vacuous -- confirm.
    assert "Errors:" not in rendered

    details = _parse_exception_message(pe)
    for scope in ("Organization", "Inventory", "Job template"):
        assert details["Violations"][scope] == ["Access not allowed."]

    assert opa_client.query_rule.call_count == 3
@pytest.mark.django_db
def test_evaluate_policy_not_found(opa_client, job):
    """An OPA reply without a "result" property is reported as an error for every queried scope."""
    opa_client.query_rule.return_value = {}

    with pytest.raises(PolicyEvaluationError) as pe:
        policy.evaluate_policy(job)

    missing_result_property = 'Call to OPA did not return a "result" property. The path refers to an undefined document.'
    details = _parse_exception_message(pe)
    for scope in ("Organization", "Inventory", "Job template"):
        assert details["Errors"][scope] == missing_result_property

    assert opa_client.query_rule.call_count == 3
@pytest.mark.django_db
def test_evaluate_policy_server_error(opa_client, job):
    """An HTTPError from OPA is reported per scope using the code and message from the error response body."""
    http_error_msg = '500 Server Error: Internal Server Error for url: https://opa.example.com:8181/v1/data/job_template/response/invalid'
    error_response = {
        'code': 'internal_error',
        'message': (
            '1 error occurred: 1:1: rego_type_error: undefined ref: data.job_template.response.invalid\n\t'
            'data.job_template.response.invalid\n\t'
            ' ^\n\t'
            ' have: "invalid"\n\t'
            ' want (one of): ["allowed" "violations"]'
        ),
    }

    # Mocked HTTP 500 response whose JSON body is the OPA error document above.
    failed_response = mock.Mock()
    failed_response.status_code = requests.codes.internal_server_error
    failed_response.json.return_value = error_response
    opa_client.query_rule.side_effect = requests.exceptions.HTTPError(http_error_msg, response=failed_response)

    with pytest.raises(PolicyEvaluationError) as pe:
        policy.evaluate_policy(job)

    details = _parse_exception_message(pe)
    for scope in ("Organization", "Inventory", "Job template"):
        assert details["Errors"][scope] == f'Call to OPA failed. Code: internal_error, Message: {error_response["message"]}'

    assert opa_client.query_rule.call_count == 3
@pytest.mark.django_db
def test_evaluate_policy_invalid_result(opa_client, job):
    """A "result" lacking the expected fields is reported as an invalid policy result for every scope."""
    opa_client.query_rule.return_value = {"result": {"absolutely": "no!"}}

    with pytest.raises(PolicyEvaluationError) as pe:
        policy.evaluate_policy(job)

    invalid_result = 'OPA policy returned invalid result.'
    details = _parse_exception_message(pe)
    for scope in ("Organization", "Inventory", "Job template"):
        assert details["Errors"][scope] == invalid_result

    assert opa_client.query_rule.call_count == 3
@pytest.mark.django_db
def test_evaluate_policy_failed_exception(opa_client, job):
    """A non-HTTP exception raised by the OPA client is reported per scope with the exception text.

    The mocked ``query_rule`` raises ``ValueError("Invalid JSON")`` on every
    call. The unused mock HTTP response and empty error dict that were built
    here before have been dropped; nothing read them.
    """
    opa_client.query_rule.side_effect = ValueError("Invalid JSON")

    with pytest.raises(PolicyEvaluationError) as pe:
        policy.evaluate_policy(job)

    opa_failed_exception = 'Call to OPA failed. Exception: Invalid JSON'
    exception = _parse_exception_message(pe)
    assert exception["Errors"]["Organization"] == opa_failed_exception
    assert exception["Errors"]["Inventory"] == opa_failed_exception
    assert exception["Errors"]["Job template"] == opa_failed_exception

    # One query per scope (organization, inventory, job template), all failing.
    assert opa_client.query_rule.call_count == 3
@pytest.mark.django_db
@pytest.mark.parametrize(
"settings_kwargs, expected_client_cert, expected_verify, verify_content",
[
# Case 1: Certificate-based authentication (mTLS)
(
{
"OPA_HOST": "opa.example.com",
"OPA_SSL": True,
"OPA_AUTH_TYPE": OPA_AUTH_TYPES.CERTIFICATE,
"OPA_AUTH_CLIENT_CERT": "-----BEGIN CERTIFICATE-----\nMIICert\n-----END CERTIFICATE-----",
"OPA_AUTH_CLIENT_KEY": "-----BEGIN PRIVATE KEY-----\nMIIKey\n-----END PRIVATE KEY-----",
"OPA_AUTH_CA_CERT": "-----BEGIN CERTIFICATE-----\nMIICACert\n-----END CERTIFICATE-----",
},
True, # Client cert should be created
"file", # Verify path should be a file
"-----BEGIN CERTIFICATE-----", # Expected content in verify file
),
# Case 2: SSL with server verification only
(
{
"OPA_HOST": "opa.example.com",
"OPA_SSL": True,
"OPA_AUTH_TYPE": OPA_AUTH_TYPES.NONE,
"OPA_AUTH_CA_CERT": "-----BEGIN CERTIFICATE-----\nMIICACert\n-----END CERTIFICATE-----",
},
False, # No client cert should be created
"file", # Verify path should be a file
"-----BEGIN CERTIFICATE-----", # Expected content in verify file
),
# Case 3: SSL with system CA store
(
{
"OPA_HOST": "opa.example.com",
"OPA_SSL": True,
"OPA_AUTH_TYPE": OPA_AUTH_TYPES.NONE,
"OPA_AUTH_CA_CERT": "", # No custom CA cert
},
False, # No client cert should be created
True, # Verify path should be True (system CA store)
None, # No file to check content
),
# Case 4: No SSL
(
{
"OPA_HOST": "opa.example.com",
"OPA_SSL": False,
"OPA_AUTH_TYPE": OPA_AUTH_TYPES.NONE,
},
False, # No client cert should be created
False, # Verify path should be False (no verification)
None, # No file to check content
),
],
ids=[
"certificate_auth",
"ssl_server_verification",
"ssl_system_ca_store",
"no_ssl",
],
)
def test_opa_cert_file(settings_kwargs, expected_client_cert, expected_verify, verify_content):
"""Parameterized test for the opa_cert_file context manager.
Tests different configurations:
- Certificate-based authentication (mTLS)
- SSL with server verification only
- SSL with system CA store
- No SSL
Parameters (one tuple per case above):
- settings_kwargs: settings applied through override_settings.
- expected_client_cert: whether a client certificate file is expected.
- expected_verify: "file" when the verify value must be a path to an
existing file, otherwise the exact object (True/False) it must be.
- verify_content: text expected inside the verify file, or None.
"""
with override_settings(**settings_kwargs):
# Kept outside the context manager so the paths can still be checked
# after it has exited.
client_cert_path = None
verify_path = None
with policy.opa_cert_file() as cert_files:
# The context manager yields a (client cert path, verify value) pair.
client_cert_path, verify_path = cert_files
# Check client cert based on expected_client_cert
if expected_client_cert:
assert client_cert_path is not None
with open(client_cert_path, 'r') as f:
content = f.read()
# The client cert file is expected to hold both certificate and key.
assert "-----BEGIN CERTIFICATE-----" in content
assert "-----BEGIN PRIVATE KEY-----" in content
else:
assert client_cert_path is None
# Check verify path based on expected_verify
if expected_verify == "file":
assert verify_path is not None
assert os.path.isfile(verify_path)
with open(verify_path, 'r') as f:
content = f.read()
assert verify_content in content
else:
# Identity check: must be exactly True or False, not merely truthy/falsy.
assert verify_path is expected_verify
# Verify files are deleted after context manager exits
if expected_client_cert:
assert not os.path.exists(client_cert_path), "Client cert file was not deleted"
if expected_verify == "file":
assert not os.path.exists(verify_path), "CA cert file was not deleted"
@pytest.mark.django_db
@override_settings(
    OPA_HOST='opa.example.com',
    OPA_SSL=False,  # SSL disabled ...
    OPA_AUTH_TYPE=OPA_AUTH_TYPES.CERTIFICATE,  # ... while certificate auth is requested
    OPA_AUTH_CLIENT_CERT="-----BEGIN CERTIFICATE-----\nMIICert\n-----END CERTIFICATE-----",
    OPA_AUTH_CLIENT_KEY="-----BEGIN PRIVATE KEY-----\nMIIKey\n-----END PRIVATE KEY-----",
)
def test_evaluate_policy_cert_auth_requires_ssl():
    """Certificate auth combined with OPA_SSL=False must fail policy evaluation."""
    # Related objects are created in the same order as the keys below.
    related = dict(
        project=Project.objects.create(name='proj1'),
        inventory=Inventory.objects.create(name='inv1', opa_query_path="inventory/response"),
        organization=Organization.objects.create(name="org1", opa_query_path="organization/response"),
        job_template=JobTemplate.objects.create(name="jt1", opa_query_path="job_template/response"),
    )
    job = Job.objects.create(name='job1', extra_vars="{}", **related)

    with pytest.raises(PolicyEvaluationError) as excinfo:
        policy.evaluate_policy(job)

    message = str(excinfo.value)
    assert "OPA_AUTH_TYPE=Certificate requires OPA_SSL to be enabled" in message
@pytest.mark.django_db
@override_settings(
    OPA_HOST='opa.example.com',
    OPA_SSL=True,
    OPA_AUTH_TYPE=OPA_AUTH_TYPES.CERTIFICATE,
    OPA_AUTH_CLIENT_CERT="",  # client cert left empty
    OPA_AUTH_CLIENT_KEY="",  # client key left empty
    OPA_AUTH_CA_CERT="",  # CA cert left empty
)
def test_evaluate_policy_missing_cert_settings():
    """Certificate auth with empty certificate settings must name each missing setting."""
    related = dict(
        project=Project.objects.create(name='proj1'),
        inventory=Inventory.objects.create(name='inv1', opa_query_path="inventory/response"),
        organization=Organization.objects.create(name="org1", opa_query_path="organization/response"),
        job_template=JobTemplate.objects.create(name="jt1", opa_query_path="job_template/response"),
    )
    job = Job.objects.create(name='job1', extra_vars="{}", **related)

    with pytest.raises(PolicyEvaluationError) as excinfo:
        policy.evaluate_policy(job)

    error_msg = str(excinfo.value)
    assert "Following certificate settings are missing for OPA_AUTH_TYPE=Certificate:" in error_msg
    for setting_name in ("OPA_AUTH_CLIENT_CERT", "OPA_AUTH_CLIENT_KEY", "OPA_AUTH_CA_CERT"):
        assert setting_name in error_msg
@pytest.mark.django_db
@override_settings(
    OPA_HOST='opa.example.com',
    OPA_PORT=8181,
    OPA_SSL=True,
    OPA_AUTH_TYPE=OPA_AUTH_TYPES.CERTIFICATE,
    OPA_AUTH_CLIENT_CERT="-----BEGIN CERTIFICATE-----\nMIICert\n-----END CERTIFICATE-----",
    OPA_AUTH_CLIENT_KEY="-----BEGIN PRIVATE KEY-----\nMIIKey\n-----END PRIVATE KEY-----",
    OPA_AUTH_CA_CERT="-----BEGIN CERTIFICATE-----\nMIICACert\n-----END CERTIFICATE-----",
    OPA_REQUEST_TIMEOUT=2.5,
    OPA_REQUEST_RETRIES=3,
)
def test_opa_client_context_manager_mtls():
    """``policy.opa_client`` builds the OPA client from settings and wires up mTLS files.

    ``OpaClient`` is replaced by a mock, so only the constructor arguments and
    the ``cert`` / ``verify`` values placed on the client's session are checked.
    """
    with mock.patch('awx.main.tasks.policy.OpaClient') as opa_client_cls:
        # The mocked class returns one instance that is also its own context value.
        fake_client = opa_client_cls.return_value
        fake_client.__enter__.return_value = fake_client
        fake_client._session = mock.MagicMock()

        with policy.opa_client(headers={'Custom-Header': 'Value'}) as client:
            opa_client_cls.assert_called_once_with(
                host='opa.example.com',
                port=8181,
                headers={'Custom-Header': 'Value'},
                ssl=True,
                cert=mock.ANY,  # temporary file path, value not predictable
                timeout=2.5,
                retries=3,
            )
            assert client is fake_client

            # Remember both paths so they can be re-checked after the context exits.
            cert_path = client._session.cert
            verify_path = client._session.verify
            assert cert_path is not None
            assert verify_path is not None

            # Client bundle: certificate followed/preceded by the private key.
            assert os.path.isfile(cert_path)
            with open(cert_path, 'r') as fh:
                client_pem = fh.read()
            for marker in ("-----BEGIN CERTIFICATE-----", "MIICert", "-----BEGIN PRIVATE KEY-----", "MIIKey"):
                assert marker in client_pem

            # CA bundle used for server verification.
            assert os.path.isfile(verify_path)
            with open(verify_path, 'r') as fh:
                ca_pem = fh.read()
            for marker in ("-----BEGIN CERTIFICATE-----", "MIICACert"):
                assert marker in ca_pem

        # Both temporary files must be removed once the context manager exits.
        assert not os.path.exists(cert_path), "Client cert file was not deleted"
        assert not os.path.exists(verify_path), "CA cert file was not deleted"
@pytest.mark.django_db
@override_settings(
    OPA_HOST='opa.example.com',
    OPA_SSL=True,
    OPA_AUTH_TYPE=OPA_AUTH_TYPES.TOKEN,
    OPA_AUTH_TOKEN='secret-token',
    OPA_AUTH_CUSTOM_HEADERS={'X-Custom': 'Header'},
)
def test_opa_client_token_auth():
    """Token auth must add a Bearer ``Authorization`` header next to the custom headers."""
    related = dict(
        project=Project.objects.create(name='proj1'),
        inventory=Inventory.objects.create(name='inv1', opa_query_path="inventory/response"),
        organization=Organization.objects.create(name="org1", opa_query_path="organization/response"),
        job_template=JobTemplate.objects.create(name="jt1", opa_query_path="job_template/response"),
    )
    job = Job.objects.create(name='job1', extra_vars="{}", **related)

    # Replace the policy.opa_client context manager (not the OpaClient class),
    # so the headers it is called with can be inspected.
    with mock.patch('awx.main.tasks.policy.opa_client') as opa_client_cm:
        stub_client = mock.MagicMock()
        stub_client.query_rule.return_value = {
            "result": {
                "allowed": True,
                "violations": [],
            }
        }
        opa_client_cm.return_value.__enter__.return_value = stub_client

        policy.evaluate_policy(job)

        opa_client_cm.assert_called_once_with(
            headers={'X-Custom': 'Header', 'Authorization': 'Bearer secret-token'},
        )

View File

@@ -436,22 +436,21 @@ def test_project_list_ordering_by_name(get, order_by, expected_names, organizati
@pytest.mark.parametrize('order_by', ('name', '-name'))
@pytest.mark.django_db
def test_project_list_ordering_with_duplicate_names(get, order_by, admin):
def test_project_list_ordering_with_duplicate_names(get, order_by, organization_factory):
# why? because all the '1' mean that all the names are the same, you can't sort based on that,
# meaning you have to fall back on the default sort order, which in this case, is ID
'ensure sorted order of project list is maintained correctly when the project names the same'
from awx.main.models import Organization
projects = []
for i in range(5):
projects.append(Project.objects.create(name='1', organization=Organization.objects.create(name=f'org{i}')))
objects = organization_factory(
'org1',
projects=['1', '1', '1', '1', '1'],
superusers=['admin'],
)
project_ids = {}
for x in range(3):
results = get(reverse('api:project_list'), user=admin, QUERY_STRING='order_by=%s' % order_by).data['results']
results = get(reverse('api:project_list'), objects.superusers.admin, QUERY_STRING='order_by=%s' % order_by).data['results']
project_ids[x] = [proj['id'] for proj in results]
assert project_ids[0] == project_ids[1] == project_ids[2]
assert project_ids[0] == sorted(project_ids[0])
assert set(project_ids[0]) == set([proj.id for proj in projects])
@pytest.mark.django_db

View File

@@ -36,7 +36,7 @@ def test_bootstrap_consistent():
assert not different_requirements
@pytest.mark.xfail(reason="This test needs some love")
@pytest.mark.skip(reason="This test needs some love")
def test_env_matches_requirements_txt():
from pip.operations import freeze

View File

@@ -74,7 +74,7 @@ class TestWebsocketEventConsumer:
connected, _ = await server.connect()
assert connected is False, "Anonymous user should NOT be allowed to login."
@pytest.mark.xfail(reason="Ran out of coding time.")
@pytest.mark.skip(reason="Ran out of coding time.")
async def test_authorized(self, websocket_server_generator, application, admin):
server = websocket_server_generator('/websocket/')

View File

@@ -1,77 +0,0 @@
import multiprocessing
import json
import pytest
import requests
from requests.auth import HTTPBasicAuth
from django.db import connection
from awx.main.models import User, JobTemplate
def create_in_subprocess(project_id, ready_event, continue_event, admin_auth):
    """Worker body: POST a job template with a fixed name once released by the parent.

    :param project_id: Primary key of the project the job template points at.
    :param ready_event: Event this worker sets once its DB connection is open.
    :param continue_event: Event this worker waits on before doing any work.
    :param admin_auth: Auth object handed to ``requests.post``.
    :return: HTTP status code of the POST (400 or 201).
    """
    jt_name = 'test_jt_duplicate_name'

    connection.connect()
    print('setting ready event')
    ready_event.set()
    print('waiting for continue event')
    continue_event.wait()

    # Remove leftovers; each lookup uses a fresh queryset so the count below is not cached.
    if JobTemplate.objects.filter(name=jt_name).exists():
        for stale_jt in JobTemplate.objects.filter(name=jt_name):
            stale_jt.delete()
    assert JobTemplate.objects.filter(name=jt_name).count() == 0

    payload = {'name': jt_name, 'project': project_id, 'playbook': 'hello_world.yml', 'ask_inventory_on_launch': True}
    response = requests.post('http://localhost:8013/api/v2/job_templates/', json=payload, auth=admin_auth)
    status = response.status_code

    # Either this worker created the template (201) or it lost to another worker (400).
    assert status in (400, 201)
    print(f'Subprocess got {status}')
    if status == 400:
        print(json.dumps(response.json(), indent=2))
    return status
@pytest.fixture
def admin_for_test():
    """Return the ``admin_for_test`` superuser, creating it with a known password if absent."""
    user, created = User.objects.get_or_create(username='admin_for_test', defaults={'is_superuser': True})
    if not created:
        return user
    # Only a freshly created user gets its password set.
    user.set_password('for_test_123!')
    user.save()
    print(f'Created user {user.username}')
    return user
@pytest.fixture
def admin_auth(admin_for_test):
    """HTTP basic-auth credentials for the ``admin_for_test`` user."""
    username = admin_for_test.username
    return HTTPBasicAuth(username, 'for_test_123!')
def test_jt_duplicate_name(admin_auth, demo_proj):
    """Several processes POST the same job template name at once; exactly one may exist after."""
    worker_count = 5
    ready_events = [multiprocessing.Event() for _ in range(worker_count)]
    continue_event = multiprocessing.Event()

    workers = []
    for ready in ready_events:
        worker = multiprocessing.Process(target=create_in_subprocess, args=(demo_proj.id, ready, continue_event, admin_auth))
        workers.append(worker)
        worker.start()

    # Wait until every worker has signalled that it is ready.
    for ready in ready_events:
        print('waiting on subprocess ready event')
        ready.wait()

    # Release all workers at once so their create requests overlap.
    print('setting the continue event for the workers')
    continue_event.set()

    # A deadlock between the workers would most likely show up while joining.
    print('waiting on the workers to finish the creation')
    for worker in workers:
        worker.join()

    assert JobTemplate.objects.filter(name='test_jt_duplicate_name').count() == 1

View File

@@ -3,7 +3,6 @@ import time
import os
import shutil
import tempfile
import logging
import pytest
@@ -20,9 +19,6 @@ from awx.main.tests import data
from awx.main.models import Project, JobTemplate, Organization, Inventory
logger = logging.getLogger(__name__)
PROJ_DATA = os.path.join(os.path.dirname(data.__file__), 'projects')
@@ -114,12 +110,6 @@ def demo_inv(default_org):
return inventory
@pytest.fixture(scope='session')
def demo_proj(default_org):
proj, _ = Project.objects.get_or_create(name='Demo Project', defaults={'organization': default_org})
return proj
@pytest.fixture
def podman_image_generator():
"""
@@ -138,29 +128,30 @@ def podman_image_generator():
@pytest.fixture
def project_factory(post, default_org, admin):
def _rf(scm_url=None, local_path=None):
proj_kwargs = {}
def run_job_from_playbook(default_org, demo_inv, post, admin):
def _rf(test_name, playbook, local_path=None, scm_url=None, jt_params=None):
project_name = f'{test_name} project'
jt_name = f'{test_name} JT: {playbook}'
old_proj = Project.objects.filter(name=project_name).first()
if old_proj:
old_proj.delete()
old_jt = JobTemplate.objects.filter(name=jt_name).first()
if old_jt:
old_jt.delete()
proj_kwargs = {'name': project_name, 'organization': default_org.id}
if local_path:
# manual path
project_name = f'Manual roject {local_path}'
proj_kwargs['scm_type'] = ''
proj_kwargs['local_path'] = local_path
elif scm_url:
project_name = f'Project {scm_url}'
proj_kwargs['scm_type'] = 'git'
proj_kwargs['scm_url'] = scm_url
else:
raise RuntimeError('Need to provide scm_url or local_path')
proj_kwargs['name'] = project_name
proj_kwargs['organization'] = default_org.id
old_proj = Project.objects.filter(name=project_name).first()
if old_proj:
logger.info(f'Deleting existing project {project_name}')
old_proj.delete()
result = post(
reverse('api:project_list'),
proj_kwargs,
@@ -168,23 +159,6 @@ def project_factory(post, default_org, admin):
expect=201,
)
proj = Project.objects.get(id=result.data['id'])
return proj
return _rf
@pytest.fixture
def run_job_from_playbook(demo_inv, post, admin, project_factory):
def _rf(test_name, playbook, local_path=None, scm_url=None, jt_params=None, proj=None, wait=True):
jt_name = f'{test_name} JT: {playbook}'
if not proj:
proj = project_factory(scm_url=scm_url, local_path=local_path)
old_jt = JobTemplate.objects.filter(name=jt_name).first()
if old_jt:
logger.info(f'Deleting existing JT {jt_name}')
old_jt.delete()
if proj.current_job:
wait_for_job(proj.current_job)
@@ -206,9 +180,7 @@ def run_job_from_playbook(demo_inv, post, admin, project_factory):
job = jt.create_unified_job()
job.signal_start()
if wait:
wait_for_job(job)
assert job.status == 'successful'
return {'job': job, 'job_template': jt, 'project': proj}
wait_for_job(job)
assert job.status == 'successful'
return _rf

View File

@@ -1,74 +0,0 @@
import time
from dispatcherd.config import settings
from dispatcherd.factories import get_control_from_settings
from dispatcherd.utils import serialize_task
from awx.main.models import JobTemplate
from awx.main.tests.data.sleep_task import sleep_break_connection, advisory_lock_exception
from awx.main.tests.live.tests.conftest import wait_for_job
def poll_for_task_finish(task_name, timeout=5.0):
    """Poll the dispatcher until no running task is named ``task_name``.

    :param task_name: Task identifier compared against each running entry's ``'task'`` value.
    :param float timeout: Seconds to keep polling before giving up (default 5.0,
        the value that used to be hard-coded).
    :raises AssertionError: If the control query does not return exactly one
        response, or matching tasks are still running after ``timeout`` seconds.
    """
    start = time.monotonic()
    ctl = get_control_from_settings()
    while True:
        responses = ctl.control_with_reply('running')
        assert len(responses) == 1
        response = responses[0]
        # 'node_id' sits beside the per-task entries; drop it before iterating them.
        response.pop('node_id')
        running_tasks = [task_data for task_data in response.values() if task_data['task'] == task_name]
        # Check completion before the deadline so a finished run is never
        # reported as a timeout just because the last poll came back late.
        if not running_tasks:
            return
        if time.monotonic() - start > timeout:
            raise AssertionError(f'Never finished working through tasks: {running_tasks}')
def check_jobs_work():
    """Launch the 'Demo Job Template' and hand the resulting job to ``wait_for_job``."""
    demo_jt = JobTemplate.objects.get(name='Demo Job Template')
    launched = demo_jt.create_unified_job()
    launched.signal_start()
    wait_for_job(launched)
def test_advisory_lock_error_clears():
    """Run a task that raises while holding advisory_lock, then check jobs still run.

    Regression test for an exception-handling bug expected to be fixed by
    https://github.com/ansible/django-ansible-base/pull/713

    This is the "easier" of the two cases in this file: it passes with the
    DAB fix alone, and passing it does not generally guarantee that workers
    are never left with a connection in a bad state.
    """
    pool_kwargs = settings.service['pool_kwargs']
    # One submission per minimum worker.
    for _ in range(pool_kwargs['min_workers']):
        advisory_lock_exception.delay()

    poll_for_task_finish(serialize_task(advisory_lock_exception))

    # Jobs should still work after the breaking task has run.
    check_jobs_work()
def test_can_recover_connection():
    """Run a task that intentionally times out the worker connection, then run a job.

    If nothing repairs the connection outside of that task's scope, every
    subsequent task would error, so a job is run after the
    sleep_break_connection tasks have finished.
    """
    pool_kwargs = settings.service['pool_kwargs']
    # One submission per minimum worker.
    for _ in range(pool_kwargs['min_workers']):
        sleep_break_connection.delay()

    poll_for_task_finish(serialize_task(sleep_break_connection))

    # Jobs should still work after the breaking task has run.
    check_jobs_work()

View File

@@ -1,20 +1,14 @@
import pytest
from awx.main.tests.live.tests.conftest import wait_for_events, wait_for_job
from awx.main.tests.live.tests.conftest import wait_for_events
from awx.main.models import Job, Inventory
@pytest.fixture
def facts_project(live_tmp_folder, project_factory):
return project_factory(scm_url=f'file://{live_tmp_folder}/facts')
def assert_facts_populated(name):
job = Job.objects.filter(name__icontains=name).order_by('-created').first()
assert job is not None
wait_for_events(job)
wait_for_job(job)
inventory = job.inventory
assert inventory.hosts.count() > 0 # sanity
@@ -23,24 +17,24 @@ def assert_facts_populated(name):
@pytest.fixture
def general_facts_test(facts_project, run_job_from_playbook):
def general_facts_test(live_tmp_folder, run_job_from_playbook):
def _rf(slug, jt_params):
jt_params['use_fact_cache'] = True
standard_kwargs = dict(jt_params=jt_params)
standard_kwargs = dict(scm_url=f'file://{live_tmp_folder}/facts', jt_params=jt_params)
# GATHER FACTS
name = f'test_gather_ansible_facts_{slug}'
run_job_from_playbook(name, 'gather.yml', proj=facts_project, **standard_kwargs)
run_job_from_playbook(name, 'gather.yml', **standard_kwargs)
assert_facts_populated(name)
# KEEP FACTS
name = f'test_clear_ansible_facts_{slug}'
run_job_from_playbook(name, 'no_op.yml', proj=facts_project, **standard_kwargs)
run_job_from_playbook(name, 'no_op.yml', **standard_kwargs)
assert_facts_populated(name)
# CLEAR FACTS
name = f'test_clear_ansible_facts_{slug}'
run_job_from_playbook(name, 'clear.yml', proj=facts_project, **standard_kwargs)
run_job_from_playbook(name, 'clear.yml', **standard_kwargs)
job = Job.objects.filter(name__icontains=name).order_by('-created').first()
assert job is not None

View File

@@ -1,78 +0,0 @@
import multiprocessing
import random
from django.db import connection
from django.utils.timezone import now
from awx.main.models import Inventory, Host
from awx.main.utils.db import bulk_update_sorted_by_id
def worker_delete_target(ready_event, continue_event, field_name):
    """Bulk-update one field on every host of the contention inventory.

    Called once per worker process, in parallel.

    :param ready_event: Set once this worker has loaded and modified its hosts in memory.
    :param continue_event: Waited on before the bulk update is issued.
    :param str field_name: Host attribute that receives the ``my_var: <i>`` value.
    """
    inventory = Inventory.objects.get(organization__name='Default', name='test_host_update_contention')
    shuffled_hosts = list(inventory.hosts.all())
    # Non-security-critical shuffle so each worker walks the hosts in a different order.
    random.shuffle(shuffled_hosts)  # NOSONAR
    for position, host in enumerate(shuffled_hosts):
        setattr(host, field_name, f'my_var: {position}')

    # Everything is staged in memory; tell the parent and wait for the go signal.
    print('worker has loaded all the hosts needed')
    ready_event.set()
    continue_event.wait()

    # NOTE: did not reproduce the bug without batch_size
    bulk_update_sorted_by_id(Host, shuffled_hosts, fields=[field_name], batch_size=100)
    print('finished doing the bulk update in worker')
def test_host_update_contention(default_org):
    """Two processes bulk-update different fields of the same 1000 hosts concurrently.

    Each worker shuffles the host list and updates one field through
    ``bulk_update_sorted_by_id``.  The test passes when both workers exit
    cleanly and every host ends up with both fields set.
    """
    inv_kwargs = dict(organization=default_org, name='test_host_update_contention')
    # Start from a clean inventory if a previous run left one behind.
    if Inventory.objects.filter(**inv_kwargs).exists():
        Inventory.objects.get(**inv_kwargs).delete()
    inv = Inventory.objects.create(**inv_kwargs)

    right_now = now()
    hosts = [Host(inventory=inv, name=f'host-{i}', created=right_now, modified=right_now) for i in range(1000)]
    print('bulk creating hosts')
    Host.objects.bulk_create(hosts)

    # sanity check
    for host in hosts:
        assert not host.variables

    # Force our worker pool to make their own connection
    connection.close()

    fields = ['variables', 'ansible_facts']
    ready_events = [multiprocessing.Event() for _ in fields]
    continue_event = multiprocessing.Event()

    print('spawning processes for concurrent bulk updates')
    processes = []
    for ready_event, field_name in zip(ready_events, fields):
        p = multiprocessing.Process(target=worker_delete_target, args=(ready_event, continue_event, field_name))
        processes.append(p)
        p.start()

    # Make sure both processes are connected and have loaded their host list
    for e in ready_events:
        print('waiting on subprocess ready event')
        e.wait()

    # Begin the bulk_update queries
    print('setting the continue event for the workers')
    continue_event.set()

    # If a deadlock happens, the affected worker exits with an error.
    print('waiting on the workers to finish the bulk_update')
    for p in processes:
        p.join()
    for p, field_name in zip(processes, fields):
        # Surface a crashed worker directly instead of relying on the field checks below.
        assert p.exitcode == 0, f'worker updating {field_name} exited with code {p.exitcode}'

    print('checking workers have variables set')
    for host in inv.hosts.all():
        assert host.variables.startswith('my_var:')
        assert host.ansible_facts.startswith('my_var:')

View File

@@ -1,224 +0,0 @@
import subprocess
import time
import os.path
from urllib.parse import urlsplit
import pytest
from unittest import mock
from awx.main.models.projects import Project
from awx.main.models.organization import Organization
from awx.main.models.inventory import Inventory, InventorySource
from awx.main.tests.live.tests.conftest import wait_for_job
NAME_PREFIX = "test-ivu"
GIT_REPO_FOLDER = "inventory_vars"
def create_new_by_name(model, **kwargs):
    """
    Create a new model instance. Delete an existing instance first.

    Only a missing instance (``model.DoesNotExist``) is tolerated; any other
    error from the lookup or from the delete propagates and nothing is created.

    :param model: The Django model.
    :param dict kwargs: The keyword arguments required to create a model
        instance. Must contain at least `name`.
    :return: The model instance.
    """
    name = kwargs["name"]
    try:
        stale = model.objects.get(name=name)
    except model.DoesNotExist:
        stale = None
    if stale is not None:
        print(f"FORCE DELETE {name}")
        stale.delete()
    # Deliberately outside any ``finally``: do not create when the cleanup above failed.
    return model.objects.create(**kwargs)
def wait_for_update(instance, timeout=3.0):
    """Wait until the last update of *instance* is finished.

    The instance is reloaded from the database on every poll, so job fields
    written elsewhere after the caller loaded it become visible here.

    :param instance: Object exposing ``current_job``, ``last_job``,
        ``last_job_run``, ``id`` and ``refresh_from_db()``.
    :param float timeout: Seconds to wait for any of the job fields to be set.
    :raises AssertionError: If none of the job fields is set within ``timeout``.
    """

    def _has_update():
        return bool(instance.current_job or instance.last_job or instance.last_job_run)

    deadline = time.time() + timeout
    while True:
        # Without this reload the loop would re-read the same in-memory values.
        instance.refresh_from_db()
        if _has_update() or time.time() >= deadline:
            break
        time.sleep(0.2)
    assert _has_update(), f'Instance never updated id={instance.id}'

    update = instance.current_job or instance.last_job
    if update:
        wait_for_job(update)
def change_source_vars_and_update(invsrc, group_vars):
    """
    Rewrite the variables file of an inventory source, commit it, and update
    the project and then the inventory source.

    Does not return before the inventory update is finished.

    :param invsrc: The inventory source instance.
    :param dict group_vars: The variables for various groups. Format::

        {
            <group>: {<variable>: <value>, <variable>: <value>, ..},
            <group>: {<variable>: <value>, <variable>: <value>, ..},
            ..
        }

    :return: None
    """
    source_project = invsrc.source_project
    # The path component of the project's scm_url is used as the repository directory.
    repo_dir = urlsplit(source_project.scm_url).path
    ini_path = os.path.join(repo_dir, invsrc.source_path)

    # One "[<group>:vars]" section per group, one "name=value" line per variable.
    ini_lines = []
    for group, variables in group_vars.items():
        ini_lines.append(f"[{group}:vars]\n")
        ini_lines.extend(f"{name}={value}\n" for name, value in variables.items())
    with open(ini_path, "w") as fp:
        fp.writelines(ini_lines)

    subprocess.run('git add .; git commit -m "Update variables in invsrc.source_path"', cwd=repo_dir, shell=True)

    # Update the project to sync the changed repo contents.
    source_project.update()
    wait_for_update(source_project)

    # Update the inventory from the changed source.
    invsrc.update()
    wait_for_update(invsrc)
@pytest.fixture
def organization():
    """Yield a freshly created ``<NAME_PREFIX>-org`` organization; delete it afterwards."""
    org_name = f"{NAME_PREFIX}-org"
    org = create_new_by_name(Organization, name=org_name, description=f"Description for {org_name}")
    yield org
    org.delete()
@pytest.fixture
def project(organization, live_tmp_folder):
    """Yield a freshly created git project pointing at the local repo folder; delete it afterwards."""
    project_name = f"{NAME_PREFIX}-project"
    fields = dict(
        name=project_name,
        description=f"Description for {project_name}",
        organization=organization,
        scm_url=f"file://{live_tmp_folder}/{GIT_REPO_FOLDER}",
        scm_type="git",
    )
    proj = create_new_by_name(Project, **fields)
    yield proj
    proj.delete()
@pytest.fixture
def inventory(organization):
    """Yield a freshly created ``<NAME_PREFIX>-inventory`` inventory; delete it afterwards."""
    inventory_name = f"{NAME_PREFIX}-inventory"
    fields = dict(
        name=inventory_name,
        description=f"Description for {inventory_name}",
        organization=organization,
    )
    inv = create_new_by_name(Inventory, **fields)
    yield inv
    inv.delete()
@pytest.fixture
def inventory_source(inventory, project):
    """Yield an SCM inventory source for ``inventory_var_deleted_in_source.ini``; delete it afterwards."""
    fields = dict(
        name=f"{NAME_PREFIX}-invsrc",
        source_project=project,
        source="scm",
        source_path="inventory_var_deleted_in_source.ini",
        inventory=inventory,
        overwrite_vars=True,
    )
    inv_src = InventorySource(**fields)
    # UnifiedJobTemplate.update is patched out for the duration of the save.
    with mock.patch('awx.main.models.unified_jobs.UnifiedJobTemplate.update'):
        inv_src.save()
    yield inv_src
    inv_src.delete()
@pytest.fixture
def inventory_source_factory(inventory, project):
    """
    Use this fixture if you want to use multiple inventory sources for the same
    inventory in your test.

    Yields ``_factory(inventory_file, name)``, which saves and returns an SCM
    inventory source named ``<NAME_PREFIX>-invsrc-<name>``.  Every source made
    through the factory is deleted at fixture teardown.
    """
    # https://docs.pytest.org/en/stable/how-to/fixtures.html#factories-as-fixtures
    created = []

    def _factory(inventory_file, name):
        # Note: The current implementation of the inventory source object
        # allows creating an instance even when the inventory source file does
        # not exist.  If this behaviour changes, create and commit an empty
        # `inventory_file` in the project's git repo here before saving the
        # instance (this needs the `live_tmp_folder` fixture to locate the repo).
        name = f"{NAME_PREFIX}-invsrc-{name}"
        inv_src = InventorySource(
            name=name,
            source_project=project,
            source="scm",
            source_path=inventory_file,
            inventory=inventory,
            overwrite_vars=True,
        )
        # UnifiedJobTemplate.update is patched out for the duration of the save.
        with mock.patch('awx.main.models.unified_jobs.UnifiedJobTemplate.update'):
            inv_src.save()
        # Track the instance so teardown can delete it.
        created.append(inv_src)
        return inv_src

    yield _factory
    for instance in created:
        instance.delete()
def test_inventory_var_deleted_in_source(inventory, inventory_source):
    """
    A variable removed from the (git) source between two updates must also
    disappear from the inventory variables.

    Verifies https://issues.redhat.com/browse/AAP-17690
    """

    def stored_vars():
        # Always re-read the inventory so the latest saved variables are compared.
        return Inventory.objects.get(name=inventory.name).variables_dict

    inventory_source.update()
    wait_for_update(inventory_source)
    assert {"a": "value_a", "b": "value_b"} == stored_vars()

    # Drop variable `a` from the source; only `b` may remain in the inventory.
    change_source_vars_and_update(inventory_source, {"all": {"b": "value_b"}})
    assert {"b": "value_b"} == stored_vars()
def test_inventory_vars_with_multiple_sources(inventory, inventory_source_factory):
    """
    Verify a sequence of updates from various sources with changing content.

    Each step rewrites one source, updates it, and compares the inventory
    variables with the expected merged result.
    """
    invsrc_a = inventory_source_factory("invsrc_a.ini", "A")
    invsrc_b = inventory_source_factory("invsrc_b.ini", "B")
    invsrc_c = inventory_source_factory("invsrc_c.ini", "C")

    # (source to change, new group vars, expected inventory variables afterwards)
    steps = [
        (invsrc_a, {"all": {"x": "x_from_a", "y": "y_from_a"}}, {"x": "x_from_a", "y": "y_from_a"}),
        (
            invsrc_b,
            {"all": {"x": "x_from_b", "y": "y_from_b", "z": "z_from_b"}},
            {"x": "x_from_b", "y": "y_from_b", "z": "z_from_b"},
        ),
        (invsrc_c, {"all": {"x": "x_from_c", "z": "z_from_c"}}, {"x": "x_from_c", "y": "y_from_b", "z": "z_from_c"}),
        (invsrc_b, {"all": {}}, {"x": "x_from_c", "y": "y_from_a", "z": "z_from_c"}),
        (invsrc_c, {"all": {"z": "z_from_c"}}, {"x": "x_from_a", "y": "y_from_a", "z": "z_from_c"}),
        (invsrc_a, {"all": {}}, {"z": "z_from_c"}),
        (invsrc_c, {"all": {}}, {}),
    ]
    for source, new_vars, expected in steps:
        change_source_vars_and_update(source, new_vars)
        assert expected == Inventory.objects.get(name=inventory.name).variables_dict

View File

@@ -1,40 +0,0 @@
import time
from awx.api.versioning import reverse
from awx.main.models import Job
from awx.main.tests.live.tests.conftest import wait_for_events
def test_cancel_and_delete_job(live_tmp_folder, run_job_from_playbook, post, delete, admin):
    """Start sleep.yml without waiting, cancel it once events arrive, then delete the job."""
    launched = run_job_from_playbook('test_cancel_and_delete_job', 'sleep.yml', scm_url=f'file://{live_tmp_folder}/debug', wait=False)
    job = launched['job']
    assert job.status == 'pending'

    # Wait for the first event so the job is known to be in progress.
    timeout = 10.0
    began = time.time()
    while not job.job_events.exists():
        time.sleep(0.2)
        assert time.time() - began <= timeout, f'Did not receive first event for job_id={job.id} in {timeout} seconds'

    # Now cancel the job
    cancel_url = reverse("api:job_cancel", kwargs={'pk': job.pk})
    post(cancel_url, user=admin, expect=202)

    # The status must reach 'canceled' within the timeout.
    timeout = 5.0
    began = time.time()
    job.refresh_from_db()
    while job.status != 'canceled':
        time.sleep(0.05)
        job.refresh_from_db(fields=['status'])
        assert time.time() - began <= timeout, f'job_id={job.id} still status={job.status} after {timeout} seconds'

    wait_for_events(job)

    detail_url = reverse("api:job_detail", kwargs={'pk': job.pk})
    delete(detail_url, user=admin, expect=204)
    assert not Job.objects.filter(id=job.id).exists()

View File

@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
import json
import os
import time
import pytest
from awx.main.models import (
@@ -13,8 +15,6 @@ from django.utils.timezone import now
from datetime import timedelta
import time
@pytest.fixture
def ref_time():
@@ -33,23 +33,15 @@ def hosts(ref_time):
def test_start_job_fact_cache(hosts, tmpdir):
# Create artifacts dir inside tmpdir
artifacts_dir = tmpdir.mkdir("artifacts")
# Assign a mock inventory ID
inventory_id = 42
# Call the function WITHOUT log_data — the decorator handles it
start_fact_cache(hosts, artifacts_dir=str(artifacts_dir), timeout=0, inventory_id=inventory_id)
# Fact files are written into artifacts_dir/fact_cache/
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
fact_cache = os.path.join(tmpdir, 'facts')
last_modified, _ = start_fact_cache(hosts, fact_cache, timeout=0)
for host in hosts:
filepath = os.path.join(fact_cache_dir, host.name)
filepath = os.path.join(fact_cache, host.name)
assert os.path.exists(filepath)
with open(filepath, 'r', encoding='utf-8') as f:
assert json.load(f) == host.ansible_facts
with open(filepath, 'r') as f:
assert f.read() == json.dumps(host.ansible_facts)
assert os.path.getmtime(filepath) <= last_modified
def test_fact_cache_with_invalid_path_traversal(tmpdir):
@@ -59,84 +51,64 @@ def test_fact_cache_with_invalid_path_traversal(tmpdir):
ansible_facts={"a": 1, "b": 2},
),
]
artifacts_dir = tmpdir.mkdir("artifacts")
inventory_id = 42
start_fact_cache(hosts, artifacts_dir=str(artifacts_dir), timeout=0, inventory_id=inventory_id)
# Fact cache directory (safe location)
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
# The bad host name should not produce a file
assert not os.path.exists(os.path.join(fact_cache_dir, '../foo'))
# Make sure the fact_cache dir exists and is still empty
assert os.listdir(fact_cache_dir) == []
fact_cache = os.path.join(tmpdir, 'facts')
start_fact_cache(hosts, fact_cache, timeout=0)
# a file called "foo" should _not_ be written outside the facts dir
assert os.listdir(os.path.join(fact_cache, '..')) == ['facts']
def test_start_job_fact_cache_past_timeout(hosts, tmpdir):
fact_cache = os.path.join(tmpdir, 'facts')
start_fact_cache(hosts, fact_cache, timeout=2)
# the hosts fixture was modified 5s ago, which is more than 2s
last_modified, _ = start_fact_cache(hosts, fact_cache, timeout=2)
assert last_modified is None
for host in hosts:
assert not os.path.exists(os.path.join(fact_cache, host.name))
ret = start_fact_cache(hosts, fact_cache, timeout=2)
assert ret is None
def test_start_job_fact_cache_within_timeout(hosts, tmpdir):
artifacts_dir = tmpdir.mkdir("artifacts")
# The hosts fixture was modified 5s ago, which is less than 7s
start_fact_cache(hosts, str(artifacts_dir), timeout=7)
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
for host in hosts:
filepath = os.path.join(fact_cache_dir, host.name)
assert os.path.exists(filepath)
with open(filepath, 'r') as f:
assert json.load(f) == host.ansible_facts
def test_finish_job_fact_cache_clear(hosts, mocker, ref_time, tmpdir):
fact_cache = os.path.join(tmpdir, 'facts')
start_fact_cache(hosts, fact_cache, timeout=0)
# the hosts fixture was modified 5s ago, which is less than 7s
last_modified, _ = start_fact_cache(hosts, fact_cache, timeout=7)
assert last_modified
bulk_update = mocker.patch('awx.main.tasks.facts.bulk_update_sorted_by_id')
for host in hosts:
assert os.path.exists(os.path.join(fact_cache, host.name))
# Mock the os.path.exists behavior for host deletion
# Let's assume the fact file for hosts[1] is missing.
mocker.patch('os.path.exists', side_effect=lambda path: hosts[1].name not in path)
# Simulate one host's fact file getting deleted manually
host_to_delete_filepath = os.path.join(fact_cache, hosts[1].name)
def test_finish_job_fact_cache_with_existing_data(hosts, mocker, tmpdir, ref_time):
fact_cache = os.path.join(tmpdir, 'facts')
last_modified, _ = start_fact_cache(hosts, fact_cache, timeout=0)
# Simulate the file being removed by checking existence first, to avoid FileNotFoundError
if os.path.exists(host_to_delete_filepath):
os.remove(host_to_delete_filepath)
bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update')
finish_fact_cache(fact_cache)
ansible_facts_new = {"foo": "bar"}
filepath = os.path.join(fact_cache, hosts[1].name)
with open(filepath, 'w') as f:
f.write(json.dumps(ansible_facts_new))
f.flush()
# I feel kind of gross about calling `os.utime` by hand, but I noticed
# that in our container-based dev environment, the resolution for
# `os.stat()` after a file write was over a second, and I don't want to put
# a sleep() in this test
new_modification_time = time.time() + 3600
os.utime(filepath, (new_modification_time, new_modification_time))
# Simulate side effects that would normally be applied during bulk update
hosts[1].ansible_facts = {}
hosts[1].ansible_facts_modified = now()
finish_fact_cache(hosts, fact_cache, last_modified)
# Verify facts are preserved for hosts with valid cache files
for host in (hosts[0], hosts[2], hosts[3]):
assert host.ansible_facts == {"a": 1, "b": 2}
assert host.ansible_facts_modified == ref_time
# Verify facts were cleared for host with deleted cache file
assert hosts[1].ansible_facts == {}
assert hosts[1].ansible_facts == ansible_facts_new
assert hosts[1].ansible_facts_modified > ref_time
# Current implementation skips the call entirely if hosts_to_update == []
bulk_update.assert_not_called()
bulk_update.assert_called_once_with([hosts[1]], ['ansible_facts', 'ansible_facts_modified'])
def test_finish_job_fact_cache_with_bad_data(hosts, mocker, tmpdir):
fact_cache = os.path.join(tmpdir, 'facts')
start_fact_cache(hosts, fact_cache, timeout=0)
last_modified, _ = start_fact_cache(hosts, fact_cache, timeout=0)
bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update')
@@ -148,6 +120,23 @@ def test_finish_job_fact_cache_with_bad_data(hosts, mocker, tmpdir):
new_modification_time = time.time() + 3600
os.utime(filepath, (new_modification_time, new_modification_time))
finish_fact_cache(fact_cache)
finish_fact_cache(hosts, fact_cache, last_modified)
bulk_update.assert_not_called()
def test_finish_job_fact_cache_clear(hosts, mocker, ref_time, tmpdir):
fact_cache = os.path.join(tmpdir, 'facts')
last_modified, _ = start_fact_cache(hosts, fact_cache, timeout=0)
bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update')
os.remove(os.path.join(fact_cache, hosts[1].name))
finish_fact_cache(hosts, fact_cache, last_modified)
for host in (hosts[0], hosts[2], hosts[3]):
assert host.ansible_facts == {"a": 1, "b": 2}
assert host.ansible_facts_modified == ref_time
assert hosts[1].ansible_facts == {}
assert hosts[1].ansible_facts_modified > ref_time
bulk_update.assert_called_once_with([hosts[1]], ['ansible_facts', 'ansible_facts_modified'])

View File

@@ -561,7 +561,7 @@ class TestBFSNodesToRun:
assert set([nodes[1], nodes[2]]) == set(g.bfs_nodes_to_run())
@pytest.mark.xfail(reason="Run manually to re-generate doc images")
@pytest.mark.skip(reason="Run manually to re-generate doc images")
class TestDocsExample:
@pytest.fixture
def complex_dag(self, wf_node_generator):

View File

@@ -32,140 +32,112 @@ def private_data_dir():
shutil.rmtree(private_data, True)
@mock.patch('awx.main.tasks.facts.update_hosts')
@mock.patch('awx.main.tasks.facts.settings')
@mock.patch('awx.main.tasks.jobs.create_partition', return_value=True)
def test_pre_post_run_hook_facts(mock_create_partition, mock_facts_settings, private_data_dir, execution_environment):
# Create mocked inventory and host queryset
inventory = mock.MagicMock(spec=Inventory, pk=1)
host1 = mock.MagicMock(spec=Host, id=1, name='host1', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=inventory)
host2 = mock.MagicMock(spec=Host, id=2, name='host2', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=inventory)
# Mock hosts queryset
hosts = [host1, host2]
qs_hosts = mock.MagicMock(spec=QuerySet)
def test_pre_post_run_hook_facts(mock_create_partition, mock_facts_settings, update_hosts, private_data_dir, execution_environment):
# creates inventory_object with two hosts
inventory = Inventory(pk=1)
mock_inventory = mock.MagicMock(spec=Inventory, wraps=inventory)
mock_inventory._state = mock.MagicMock()
qs_hosts = QuerySet()
hosts = [
Host(id=1, name='host1', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=mock_inventory),
Host(id=2, name='host2', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=mock_inventory),
]
qs_hosts._result_cache = hosts
qs_hosts.only.return_value = hosts
qs_hosts.count.side_effect = lambda: len(qs_hosts._result_cache)
inventory.hosts = qs_hosts
qs_hosts.only = mock.MagicMock(return_value=hosts)
mock_inventory.hosts = qs_hosts
assert mock_inventory.hosts.count() == 2
# Create mocked job object
org = mock.MagicMock(spec=Organization, pk=1)
proj = mock.MagicMock(spec=Project, pk=1, organization=org)
job = mock.MagicMock(
spec=Job,
use_fact_cache=True,
project=proj,
organization=org,
job_slice_number=1,
job_slice_count=1,
inventory=inventory,
execution_environment=execution_environment,
)
job.get_hosts_for_fact_cache = Job.get_hosts_for_fact_cache.__get__(job)
# creates job object with fact_cache enabled
org = Organization(pk=1)
proj = Project(pk=1, organization=org)
job = mock.MagicMock(spec=Job, use_fact_cache=True, project=proj, organization=org, job_slice_number=1, job_slice_count=1)
job.inventory = mock_inventory
job.execution_environment = execution_environment
job.get_hosts_for_fact_cache = Job.get_hosts_for_fact_cache.__get__(job) # to run original method
job.job_env.get = mock.MagicMock(return_value=private_data_dir)
# Mock RunJob task
mock_facts_settings.ANSIBLE_FACT_CACHE_TIMEOUT = False
# creates the task object with job object as instance
mock_facts_settings.ANSIBLE_FACT_CACHE_TIMEOUT = False # defines timeout to false
task = jobs.RunJob()
task.instance = job
task.update_model = mock.Mock(return_value=job)
task.model.objects.get = mock.Mock(return_value=job)
# Run pre_run_hook
# run pre_run_hook
task.facts_write_time = task.pre_run_hook(job, private_data_dir)
# Add a third mocked host
host3 = mock.MagicMock(spec=Host, id=3, name='host3', ansible_facts={"added": True}, ansible_facts_modified=now(), inventory=inventory)
qs_hosts._result_cache.append(host3)
assert inventory.hosts.count() == 3
# updates inventory with one more host
hosts.append(Host(id=3, name='host3', ansible_facts={"added": True}, ansible_facts_modified=now(), inventory=mock_inventory))
assert mock_inventory.hosts.count() == 3
# Run post_run_hook
# run post_run_hook
task.runner_callback.artifacts_processed = mock.MagicMock(return_value=True)
task.post_run_hook(job, "success")
# Verify final host facts
assert qs_hosts._result_cache[2].ansible_facts == {"added": True}
assert mock_inventory.hosts[2].ansible_facts == {"added": True}
@mock.patch('awx.main.tasks.facts.bulk_update_sorted_by_id')
@mock.patch('awx.main.tasks.facts.update_hosts')
@mock.patch('awx.main.tasks.facts.settings')
@mock.patch('awx.main.tasks.jobs.create_partition', return_value=True)
def test_pre_post_run_hook_facts_deleted_sliced(mock_create_partition, mock_facts_settings, private_data_dir, execution_environment):
# Fully mocked inventory
mock_inventory = mock.MagicMock(spec=Inventory)
def test_pre_post_run_hook_facts_deleted_sliced(mock_create_partition, mock_facts_settings, update_hosts, private_data_dir, execution_environment):
# creates inventory_object with two hosts
inventory = Inventory(pk=1)
mock_inventory = mock.MagicMock(spec=Inventory, wraps=inventory)
mock_inventory._state = mock.MagicMock()
qs_hosts = QuerySet()
hosts = [Host(id=num, name=f'host{num}', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=mock_inventory) for num in range(999)]
# Create 999 mocked Host instances
hosts = []
for i in range(999):
host = mock.MagicMock(spec=Host)
host.id = i
host.name = f'host{i}'
host.ansible_facts = {"a": 1, "b": 2}
host.ansible_facts_modified = now()
host.inventory = mock_inventory
hosts.append(host)
qs_hosts._result_cache = hosts
qs_hosts.only = mock.MagicMock(return_value=hosts)
mock_inventory.hosts = qs_hosts
assert mock_inventory.hosts.count() == 999
# Mock inventory.hosts behavior
mock_qs_hosts = mock.MagicMock()
mock_qs_hosts.only.return_value = hosts
mock_qs_hosts.count.return_value = 999
mock_inventory.hosts = mock_qs_hosts
# Mock Organization and Project
org = mock.MagicMock(spec=Organization)
proj = mock.MagicMock(spec=Project)
proj.organization = org
# Mock job object
job = mock.MagicMock(spec=Job)
job.use_fact_cache = True
job.project = proj
job.organization = org
job.job_slice_number = 1
job.job_slice_count = 3
job.execution_environment = execution_environment
# creates job object with fact_cache enabled
org = Organization(pk=1)
proj = Project(pk=1, organization=org)
job = mock.MagicMock(spec=Job, use_fact_cache=True, project=proj, organization=org, job_slice_number=1, job_slice_count=3)
job.inventory = mock_inventory
job.job_env.get.return_value = private_data_dir
job.execution_environment = execution_environment
job.get_hosts_for_fact_cache = Job.get_hosts_for_fact_cache.__get__(job) # to run original method
job.job_env.get = mock.MagicMock(return_value=private_data_dir)
# Bind actual method for host filtering
job.get_hosts_for_fact_cache = Job.get_hosts_for_fact_cache.__get__(job)
# Mock task instance
# creates the task object with job object as instance
mock_facts_settings.ANSIBLE_FACT_CACHE_TIMEOUT = False
task = jobs.RunJob()
task.instance = job
task.update_model = mock.Mock(return_value=job)
task.model.objects.get = mock.Mock(return_value=job)
# Call pre_run_hook
# run pre_run_hook
task.facts_write_time = task.pre_run_hook(job, private_data_dir)
# Simulate one host deletion
hosts.pop(1)
mock_qs_hosts.count.return_value = 998
assert mock_inventory.hosts.count() == 998
# Call post_run_hook
# run post_run_hook
task.runner_callback.artifacts_processed = mock.MagicMock(return_value=True)
task.post_run_hook(job, "success")
# Assert that ansible_facts were preserved
for host in hosts:
assert host.ansible_facts == {"a": 1, "b": 2}
# Add expected failure cases
failures = []
for host in hosts:
try:
assert host.ansible_facts == {"a": 1, "b": 2, "unexpected_key": "bad"}
except AssertionError:
failures.append(f"Host named {host.name} has facts {host.ansible_facts}")
failures.append("Host named {} has facts {}".format(host.name, host.ansible_facts))
assert len(failures) > 0, f"Failures occurred for the following hosts: {failures}"
@mock.patch('awx.main.tasks.facts.bulk_update_sorted_by_id')
@mock.patch('awx.main.tasks.facts.update_hosts')
@mock.patch('awx.main.tasks.facts.settings')
def test_invalid_host_facts(mock_facts_settings, bulk_update_sorted_by_id, private_data_dir, execution_environment):
def test_invalid_host_facts(mock_facts_settings, update_hosts, private_data_dir, execution_environment):
inventory = Inventory(pk=1)
mock_inventory = mock.MagicMock(spec=Inventory, wraps=inventory)
mock_inventory._state = mock.MagicMock()
@@ -183,7 +155,7 @@ def test_invalid_host_facts(mock_facts_settings, bulk_update_sorted_by_id, priva
failures.append(host.name)
mock_facts_settings.SOME_SETTING = True
bulk_update_sorted_by_id(Host, mock_inventory.hosts, fields=['ansible_facts'])
update_hosts(mock_inventory.hosts)
with pytest.raises(pytest.fail.Exception):
if failures:

View File

@@ -50,7 +50,7 @@ def test_outer_inner_signal_handling():
@with_signal_handling
def f1():
assert signal_callback() is False
signal_state.set_signal_flag(for_signal=signal.SIGTERM)
signal_state.set_sigterm_flag()
assert signal_callback()
f2()
@@ -74,7 +74,7 @@ def test_inner_outer_signal_handling():
@with_signal_handling
def f2():
assert signal_callback() is False
signal_state.set_signal_flag(for_signal=signal.SIGINT)
signal_state.set_sigint_flag()
assert signal_callback()
@with_signal_handling

View File

@@ -107,7 +107,7 @@ def job():
@pytest.fixture
def adhoc_job():
return AdHocCommand(pk=1, id=1, inventory=Inventory(), status='waiting')
return AdHocCommand(pk=1, id=1, inventory=Inventory())
@pytest.fixture
@@ -472,7 +472,7 @@ class TestGenericRun:
task.model.objects.get = mock.Mock(return_value=job)
task.build_private_data_files = mock.Mock(side_effect=OSError())
with mock.patch('awx.main.tasks.jobs.shutil.copytree'), mock.patch('awx.main.tasks.jobs.evaluate_policy'):
with mock.patch('awx.main.tasks.jobs.shutil.copytree'):
with pytest.raises(Exception):
task.run(1)
@@ -481,6 +481,26 @@ class TestGenericRun:
assert update_model_call['status'] == 'error'
assert update_model_call['emitted_events'] == 0
def test_cancel_flag(self, job, update_model_wrapper, execution_environment, mock_me, mock_create_partition):
    """Run a job that already has its cancel flag set and check it is recorded as canceled."""
    # Job is nominally running, but a cancel has been requested.
    job.status = 'running'
    job.cancel_flag = True
    job.execution_environment = execution_environment
    # Replace the outbound hooks with mocks.
    job.websocket_emit_status = mock.Mock()
    job.send_notification_templates = mock.Mock()

    task = jobs.RunJob()
    task.instance = job
    task.model.objects.get = mock.Mock(return_value=job)
    task.build_private_data_files = mock.Mock()
    # Wrap (rather than replace) the update helper so its calls can be inspected afterwards.
    task.update_model = mock.Mock(wraps=update_model_wrapper)

    # The run is expected to raise; copytree is patched out for its duration.
    with mock.patch('awx.main.tasks.jobs.shutil.copytree'), pytest.raises(Exception):
        task.run(1)

    expected_call = mock.call(1, start_args='', status='canceled')
    assert expected_call in task.update_model.call_args_list
def test_event_count(self, mock_me):
task = jobs.RunJob()
task.runner_callback.dispatcher = mock.MagicMock()
@@ -569,8 +589,6 @@ class TestAdhocRun(TestJobExecution):
adhoc_job.send_notification_templates = mock.Mock()
task = jobs.RunAdHocCommand()
adhoc_job.status = 'running' # to bypass status flip
task.instance = adhoc_job # to bypass fetch
task.update_model = mock.Mock(wraps=adhoc_update_model_wrapper)
task.model.objects.get = mock.Mock(return_value=adhoc_job)
task.build_inventory = mock.Mock()

View File

@@ -1,110 +0,0 @@
"""
Test utility functions and classes for inventory variable handling.
"""
import pytest
from awx.main.utils.inventory_vars import InventoryVariable
from awx.main.utils.inventory_vars import InventoryGroupVariables
def test_inventory_variable_update_basic():
    """Walk one variable through a series of updates and deletes, checking it after each step."""
    variable = InventoryVariable("x")
    assert variable.has_no_source

    # After every update the string form must reflect the value just written.
    for new_value, source_id in ((1, 101), (2, 102), (3, 103)):
        variable.update(new_value, source_id)
        assert str(variable) == str(new_value)

    # Deleting a source that is not on top leaves the current value alone;
    # deleting the top source exposes the previous remaining one.
    for source_id, expected_text in ((102, "3"), (103, "1")):
        variable.delete(source_id)
        assert str(variable) == expected_text

    # Removing the last source leaves the variable without value or source.
    variable.delete(101)
    assert variable.value is None
    assert variable.has_no_source
@pytest.mark.parametrize(
    "updates",  # sequence of (<source_id>, <value>, <expected_value>) steps
    [
        ((101, 1, 1),),
        ((101, 1, 1), (101, None, None)),
        ((101, 1, 1), (102, 2, 2), (102, None, 1)),
        ((101, 1, 1), (102, 2, 2), (101, None, 2), (102, None, None)),
        (
            (101, 0, 0),
            (101, 1, 1),
            (102, 2, 2),
            (103, 3, 3),
            (102, None, 3),
            (103, None, 1),
            (101, None, None),
        ),
    ],
)
def test_inventory_variable_update(updates: tuple[tuple[int, int | None, int | None], ...]):
    """
    Test if the variable value is set correctly on a sequence of updates.

    For this test, the value `None` implies the deletion of the source.

    :param updates: The steps to apply in order. Each step is a
        ``(source_id, value, expected_value)`` triple; ``expected_value`` is
        the variable's value right after that step.
    """
    x = InventoryVariable("x")
    for src_id, value, expected_value in updates:
        if value is None:
            # A `None` value stands for "this source no longer provides the variable".
            x.delete(src_id)
        else:
            x.update(value, src_id)
        # Checked after every single step, not only at the end.
        assert x.value == expected_value
def test_inventory_group_variables_update_basic():
    """One update from a single source must show up unchanged in the group variables."""
    group_vars = InventoryGroupVariables(1)
    incoming = {"x": 1, "y": 2}
    group_vars.update_from_src(incoming, 101)
    assert group_vars == {"x": 1, "y": 2}
@pytest.mark.parametrize(
    "updates",  # sequence of (<source_id>, <vars>: dict, <expected_vars>: dict) steps
    [
        ((101, {"x": 1, "y": 1}, {"x": 1, "y": 1}),),
        (
            (101, {"x": 1, "y": 1}, {"x": 1, "y": 1}),
            (102, {}, {"x": 1, "y": 1}),
        ),
        (
            (101, {"x": 1, "y": 1}, {"x": 1, "y": 1}),
            (102, {"x": 2}, {"x": 2, "y": 1}),
        ),
        (
            (101, {"x": 1, "y": 1}, {"x": 1, "y": 1}),
            (102, {"x": 2, "y": 2}, {"x": 2, "y": 2}),
        ),
        (
            (101, {"x": 1, "y": 1}, {"x": 1, "y": 1}),
            (102, {"x": 2, "z": 2}, {"x": 2, "y": 1, "z": 2}),
        ),
        (
            (101, {"x": 1, "y": 1}, {"x": 1, "y": 1}),
            (102, {"x": 2, "z": 2}, {"x": 2, "y": 1, "z": 2}),
            (102, {}, {"x": 1, "y": 1}),
        ),
        (
            (101, {"x": 1, "y": 1}, {"x": 1, "y": 1}),
            (102, {"x": 2, "z": 2}, {"x": 2, "y": 1, "z": 2}),
            (103, {"x": 3}, {"x": 3, "y": 1, "z": 2}),
            (101, {}, {"x": 3, "z": 2}),
        ),
    ],
)
def test_inventory_group_variables_update(updates: tuple[tuple[int, dict, dict], ...]):
    """
    Test if the group vars are set correctly on various update sequences.

    :param updates: The steps to apply in order. Each step is a
        ``(source_id, vars, expected_vars)`` triple; ``expected_vars`` is the
        full group variables dict expected right after that step.
    """
    groupvars = InventoryGroupVariables(2)
    # `new_vars` rather than `vars`, so the builtin of that name is not shadowed.
    for src_id, new_vars, expected_vars in updates:
        groupvars.update_from_src(new_vars, src_id)
        assert groupvars == expected_vars

View File

@@ -1,37 +0,0 @@
import json
from http import HTTPStatus
from unittest.mock import patch
from requests import Response
from awx.main.utils.licensing import Licenser
def test_rhsm_licensing():
    """get_rhsm_subs must return an empty list when the mocked reply has an empty 'body'."""

    def fake_make_request(*args, **kwargs):
        # The request is expected to carry verify=True.
        assert kwargs['verify'] == True
        payload = json.dumps({'body': []})
        reply = Response()
        reply.status_code = HTTPStatus.OK
        reply._content = payload.encode('utf-8')
        return reply

    licenser = Licenser()
    # Swap the OIDC client's request method for the canned reply above.
    with patch('awx.main.utils.analytics_proxy.OIDCClient.make_request', new=fake_make_request):
        result = licenser.get_rhsm_subs('localhost', 'admin', 'admin')
        assert result == []
def test_satellite_licensing():
    """get_satellite_subs must return an empty list when the mocked reply has no 'results'."""

    def fake_get(*args, **kwargs):
        # The request is expected to carry verify=True.
        assert kwargs['verify'] == True
        payload = json.dumps({'results': []})
        reply = Response()
        reply.status_code = HTTPStatus.OK
        reply._content = payload.encode('utf-8')
        return reply

    licenser = Licenser()
    # Replace requests.get with the canned reply above.
    with patch('requests.get', new=fake_get):
        result = licenser.get_satellite_subs('localhost', 'admin', 'admin')
        assert result == []

View File

@@ -23,7 +23,7 @@ class TokenError(requests.RequestException):
try:
client = OIDCClient(...)
client.make_request(...)
except TokenError as e:
except TokenGenerationError as e:
print(f"Token generation failed due to {e.__cause__}")
except requests.RequestException:
print("API request failed")
@@ -102,15 +102,13 @@ class OIDCClient:
self,
client_id: str,
client_secret: str,
token_url: str = DEFAULT_OIDC_TOKEN_ENDPOINT,
scopes: list[str] = None,
token_url: str,
scopes: list[str],
base_url: str = '',
) -> None:
self.client_id: str = client_id
self.client_secret: str = client_secret
self.token_url: str = token_url
if scopes is None:
scopes = ['api.console']
self.scopes = scopes
self.base_url: str = base_url
self.token: Optional[Token] = None

View File

@@ -1,34 +1,10 @@
# Copyright (c) 2017 Ansible by Red Hat
# All Rights Reserved.
from awx.settings.application_name import set_application_name
from awx.settings.application_name import set_application_name
from django.conf import settings
def set_connection_name(function):
set_application_name(settings.DATABASES, settings.CLUSTER_HOST_ID, function=function)
def bulk_update_sorted_by_id(model, objects, fields, batch_size=1000):
    """
    Bulk-update model instances in ascending ``id`` order to avoid database deadlocks.

    This helper was introduced to prevent deadlocks observed in the AWX Controller
    when concurrent jobs update different fields on the same `main_hosts` table,
    e.g. one process updating `last_job_id` while another simultaneously updates
    `ansible_facts`.

    Sorting the updates by ID gives every process the same update order, which
    helps avoid the row-level locking contention that can lead to deadlocks in
    PostgreSQL when multiple processes are involved.

    Args:
        model: Model class whose ``objects.bulk_update`` performs the write.
        objects: Iterable of instances; those whose ``id`` is ``None`` are skipped.
        fields: Field names handed through to ``bulk_update``.
        batch_size: Batch size handed through to ``bulk_update``.

    Returns:
        int: The number of rows affected by the update, as reported by
        ``bulk_update``; ``0`` when there is nothing to update.
    """
    ordered = sorted(
        (candidate for candidate in objects if candidate.id is not None),
        key=lambda candidate: candidate.id,
    )
    if ordered:
        return model.objects.bulk_update(ordered, fields, batch_size=batch_size)
    # No instance carries an id, so nothing is written.
    return 0

View File

@@ -6,7 +6,7 @@ import urllib.parse as urlparse
from django.conf import settings
from awx.main.utils.reload import supervisor_service_command
from awx.main.dispatch.publish import task as task_awx
from awx.main.dispatch.publish import task
def construct_rsyslog_conf_template(settings=settings):
@@ -139,7 +139,7 @@ def construct_rsyslog_conf_template(settings=settings):
return tmpl
@task_awx(queue='rsyslog_configurer')
@task(queue='rsyslog_configurer')
def reconfigure_rsyslog():
tmpl = construct_rsyslog_conf_template()
# Write config to a temp file then move it to preserve atomicity

View File

@@ -1,277 +0,0 @@
import logging
from typing import TypeAlias, Any
from awx.main.models import InventoryGroupVariablesWithHistory
var_value: TypeAlias = Any
update_queue: TypeAlias = list[tuple[int, var_value]]
logger = logging.getLogger('awx.api.inventory_import')
class InventoryVariable:
    """
    Represents an inventory variable.

    This class keeps track of the variable updates from different inventory
    sources.
    """

    def __init__(self, name: str) -> None:
        """
        :param str name: The variable's name.
        :return: None
        """
        self.name = name
        # A queue representing updates from inventory sources in the sequence
        # of occurrence.
        #
        # The queue is realized as a list of two-tuples
        # `(inventory_source_id, value)`. The last item of the list is
        # considered the top of the queue, and holds the current value of the
        # variable.
        self._update_queue: update_queue = []

    def reset(self) -> None:
        """Reset the variable by deleting its history."""
        self._update_queue = []

    def load(self, updates: update_queue) -> "InventoryVariable":
        """
        Load internal state from a list.

        The list is stored as-is (not copied).

        :param updates: The update history to adopt.
        :return: This instance, to allow chaining.
        """
        self._update_queue = updates
        return self

    def dump(self) -> update_queue:
        """
        Save internal state to a list.

        :return: The internal update history itself (not a copy).
        """
        return self._update_queue

    def update(self, value: var_value, invsrc_id: int) -> None:
        """
        Update the variable with a new value from an inventory source.

        Updating means that this source is moved to the top of the queue
        and `value` becomes the new current value.

        :param value: The new value of the variable.
        :param int invsrc_id: The inventory source of the new variable value.
        :return: None
        """
        logger.debug(f"InventoryVariable().update({value}, {invsrc_id}):")
        # Move this source to the top of the queue (the end of the list) by
        # first deleting a possibly existing entry, and then appending the
        # new entry.
        self.delete(invsrc_id)
        self._update_queue.append((invsrc_id, value))

    def delete(self, invsrc_id: int) -> None:
        """
        Delete an inventory source from the variable.

        Does nothing if the source has no entry in the history.

        :param int invsrc_id: The inventory source id.
        :return: None
        """
        data_index = self._get_invsrc_index(invsrc_id)
        # Remove last update from this source, if there was any.
        if data_index is not None:
            value = self._update_queue.pop(data_index)[1]
            logger.debug(f"InventoryVariable().delete({invsrc_id}): {data_index=} {value=}")

    def _get_invsrc_index(self, invsrc_id: int) -> int | None:
        """Return the inventory source's position in the queue, or `None`."""
        for i, entry in enumerate(self._update_queue):
            if entry[0] == invsrc_id:
                return i
        return None

    def _get_current_value(self) -> var_value:
        """
        Return the current value of the variable, or None if the variable has no
        history.
        """
        return self._update_queue[-1][1] if self._update_queue else None

    @property
    def value(self) -> var_value:
        """Read the current value of the variable."""
        return self._get_current_value()

    @property
    def has_no_source(self) -> bool:
        """True, if the variable is orphan, i.e. no source contains this var anymore."""
        return not self._update_queue

    def __str__(self):
        """
        Return the string representation of the current value.

        A variable without a value (`None`) is rendered as the empty string.
        """
        current = self.value
        # Test for `None` explicitly: a plain truthiness check would also
        # turn meaningful falsy values such as `0` or `False` into "".
        return "" if current is None else str(current)
class InventoryGroupVariables(dict):
    """
    All inventory variables of one group.

    The dict itself maps each variable name to its current value, taking the
    update history of the inventory sources into account.

    Note that variable values cannot be `None`; use the empty string to
    indicate that a variable holds no value. See also `InventoryVariable`.
    """

    def __init__(self, id: int) -> None:
        """
        :param int id: The id of the group object.
        :return: None
        """
        super().__init__()
        self.id = id
        # Per-variable history objects. They record every source that set a
        # variable, so the current value (the one from the latest update that
        # defined the variable) can always be determined.
        self._vars: dict[str, InventoryVariable] = {}

    def _sync_vars(self) -> None:
        """
        Mirror the current value of every tracked variable into the dict.

        Must be called whenever the `_vars` structure has been modified.
        """
        for var_name in self._vars:
            self[var_name] = self._vars[var_name].value

    def load_state(self, state: dict[str, update_queue]) -> "InventoryGroupVariables":
        """
        Restore the internal state from a dict of update histories.

        :param state: Maps variable names to their update history.
        :return: This instance, to allow chaining.
        """
        for var_name, history in state.items():
            restored = InventoryVariable(var_name)
            self._vars[var_name] = restored.load(history)
        self._sync_vars()
        return self

    def save_state(self) -> dict[str, update_queue]:
        """Return the internal state as a dict of update histories."""
        return {var_name: tracked.dump() for var_name, tracked in self._vars.items()}

    def update_from_src(
        self,
        new_vars: dict[str, var_value],
        source_id: int,
        overwrite_vars: bool = True,
        reset: bool = False,
    ) -> None:
        """
        Apply the variables delivered by one inventory source.

        :param dict new_vars: The variables from the inventory source.
        :param int source_id: The id of the inventory source for this update.
        :param bool overwrite_vars: If `True`, remove this source's history
            entry from variables which are not part of this update. If
            `False`, such entries are kept. Default is `True`.
        :param bool reset: If `True`, wipe the update history of all existing
            variables before applying the new ones, so that this update
            overwrites all history. Default is `False`.
        :return: None
        """
        logger.debug(f"InventoryGroupVariables({self.id}).update_from_src({new_vars=}, {source_id=}, {overwrite_vars=}, {reset=}): {self=}")

        # Start tracking variables this source introduces for the first time.
        for fresh_name in [candidate for candidate in new_vars if candidate not in self._vars]:
            self._vars[fresh_name] = InventoryVariable(fresh_name)

        # Names of the variables known so far plus those in this update.
        all_var_names = list({*self.keys(), *new_vars.keys()})

        # Reset mode: drop every variable's history before applying the update.
        if reset:
            for var_name in all_var_names:
                self._vars[var_name].reset()

        for var_name in all_var_names:
            tracked = self._vars[var_name]
            if name_is_updated := var_name in new_vars:
                # The source (re)defines this variable.
                tracked.update(new_vars[var_name], source_id)
            if not name_is_updated and overwrite_vars:
                # The source no longer provides this variable.
                tracked.delete(source_id)
            # A variable that no source provides anymore is dropped entirely.
            if tracked.has_no_source:
                del self._vars[var_name]
                del self[var_name]

        # Bring the dict view in line with the possibly changed current values.
        self._sync_vars()
        logger.debug(f"InventoryGroupVariables({self.id}).update_from_src(): {self=}")
def update_group_variables(
    group_id: int | None,
    newvars: dict,
    dbvars: dict | None,
    invsrc_id: int,
    inventory_id: int,
    overwrite_vars: bool = True,
    reset: bool = False,
) -> dict[str, var_value]:
    """
    Merge an update into the inventory variables of one group and persist the result.

    The stored update history for the group is loaded (or created, seeded with
    `dbvars`), the new variables are applied on top of it, and the resulting
    history is saved back to the database.

    The update can be triggered either by an inventory update via API, or via a
    manual edit of the variables field in the awx inventory form.

    TODO: Can we get rid of the dbvars? This is only needed because the new
    update-var mechanism needs to be properly initialized if the db already
    contains some variables.

    :param int group_id: The inventory group id (pk). For the 'all'-group use
        `None`, because this group is not an actual `Group` object in the
        database.
    :param dict newvars: The variables contained in this update.
    :param dict dbvars: The variables which are already stored in the database
        for this inventory and this group. Can be `None`.
    :param int invsrc_id: The id of the inventory source. Usually this is the
        database primary key of the inventory source object, but there is one
        special id -1 which is used for the initial update from the database and
        for manual updates via the GUI.
    :param int inventory_id: The id of the inventory on which this update is
        applied.
    :param bool overwrite_vars: If `True`, delete variables which were merged
        from the same source in a previous update, but are no longer contained
        in that source. If `False`, such variables are not removed from the
        group. Default is `True`.
    :param bool reset: If `True`, delete all variables from previous updates,
        so that this update overwrites all history. Default is `False`.
    :return: The variables and their current values as a dict.
    :rtype: dict
    """
    history_cls = InventoryGroupVariablesWithHistory
    lookup = {"inventory_id": inventory_id, "group_id": group_id}
    inv_group_vars = InventoryGroupVariables(group_id)

    # Look for a previously stored state of this group's variables.
    try:
        model = history_cls.objects.get(**lookup)
    except history_cls.DoesNotExist:
        model = None

    if model is None:
        # No previous state: start a new database object and seed the history
        # with the variables the group currently has.
        model = history_cls(**lookup)
        if dbvars:
            inv_group_vars.update_from_src(dbvars, -1)  # -1 stands in as source id for pre-existing vars.
    else:
        # Continue from the stored state.
        inv_group_vars.load_state(model.variables)

    logger.debug(f"update_group_variables: before update_from_src {model.variables=}")

    # Apply this update on top of the restored history.
    inv_group_vars.update_from_src(newvars, invsrc_id, overwrite_vars, reset)

    # Persist the resulting history.
    model.variables = inv_group_vars.save_state()
    model.save()

    logger.debug(f"update_group_variables: after update_from_src {model.variables=}")
    logger.debug(f"update_group_variables({group_id=}, {newvars}): {inv_group_vars}")
    return inv_group_vars

View File

@@ -38,7 +38,6 @@ from django.utils.translation import gettext_lazy as _
from awx_plugins.interfaces._temporary_private_licensing_api import detect_server_product_name
from awx.main.constants import SUBSCRIPTION_USAGE_MODEL_UNIQUE_HOSTS
from awx.main.utils.analytics_proxy import OIDCClient
MAX_INSTANCES = 9999999
@@ -229,47 +228,37 @@ class Licenser(object):
host = getattr(settings, 'REDHAT_CANDLEPIN_HOST', None)
if not user:
raise ValueError('subscriptions_client_id is required')
raise ValueError('subscriptions_username is required')
if not pw:
raise ValueError('subscriptions_client_secret is required')
raise ValueError('subscriptions_password is required')
if host and user and pw:
if 'subscription.rhsm.redhat.com' in host:
json = self.get_rhsm_subs(settings.SUBSCRIPTIONS_RHSM_URL, user, pw)
json = self.get_rhsm_subs(host, user, pw)
else:
json = self.get_satellite_subs(host, user, pw)
return self.generate_license_options_from_entitlements(json)
return []
def get_rhsm_subs(self, host, client_id, client_secret):
def get_rhsm_subs(self, host, user, pw):
verify = getattr(settings, 'REDHAT_CANDLEPIN_VERIFY', True)
json = []
try:
client = OIDCClient(client_id, client_secret)
subs = client.make_request(
'GET',
host,
verify=True,
timeout=(5, 20),
)
except requests.RequestException:
logger.warning("Failed to connect to console.redhat.com using Service Account credentials. Falling back to basic auth.")
subs = requests.request(
'GET',
host,
auth=(client_id, client_secret),
verify=True,
timeout=(5, 20),
)
subs = requests.get('/'.join([host, 'subscription/users/{}/owners'.format(user)]), verify=verify, auth=(user, pw))
except requests.exceptions.ConnectionError as error:
raise error
except OSError as error:
raise OSError(
'Unable to open certificate bundle {}. Check that the service is running on Red Hat Enterprise Linux.'.format(verify)
) from error # noqa
subs.raise_for_status()
subs_formatted = []
for sku in subs.json()['body']:
sku_data = {k: v for k, v in sku.items() if k != 'subscriptions'}
for sub in sku['subscriptions']:
sub_data = sku_data.copy()
sub_data['subscriptions'] = sub
subs_formatted.append(sub_data)
return subs_formatted
for sub in subs.json():
resp = requests.get('/'.join([host, 'subscription/owners/{}/pools/?match=*tower*'.format(sub['key'])]), verify=verify, auth=(user, pw))
resp.raise_for_status()
json.extend(resp.json())
return json
def get_satellite_subs(self, host, user, pw):
port = None
@@ -278,7 +267,7 @@ class Licenser(object):
port = str(self.config.get("server", "port"))
except Exception as e:
logger.exception('Unable to read rhsm config to get ca_cert location. {}'.format(str(e)))
verify = True
verify = getattr(settings, 'REDHAT_CANDLEPIN_VERIFY', True)
if port:
host = ':'.join([host, port])
json = []
@@ -325,11 +314,20 @@ class Licenser(object):
return False
return True
def is_appropriate_sub(self, sub):
    """Decide whether a subscription entry qualifies as a license source.

    :param dict sub: The subscription entry to inspect. The key
        ``activeSubscription`` is read with a direct lookup; the key
        ``providedProducts`` is optional and treated as empty when absent.
    :return: `False` if ``activeSubscription`` is exactly `False`, otherwise
        `True` only when one of the provided products has the product id
        ``'480'``.
    :rtype: bool
    """
    # Only the literal value False disqualifies the entry here; any other
    # value falls through to the product check below.
    if sub['activeSubscription'] is False:
        return False
    # Products that contain Ansible Tower
    for provided in sub.get('providedProducts', []):
        if provided.get('productId') == '480':
            return True
    return False
def generate_license_options_from_entitlements(self, json):
from dateutil.parser import parse
ValidSub = collections.namedtuple(
'ValidSub', 'sku name support_level end_date trial developer_license quantity satellite subscription_id account_number usage'
'ValidSub', 'sku name support_level end_date trial developer_license quantity pool_id satellite subscription_id account_number usage'
)
valid_subs = []
for sub in json:
@@ -337,14 +335,10 @@ class Licenser(object):
if satellite:
is_valid = self.is_appropriate_sat_sub(sub)
else:
# the list of subs from console.redhat.com are already valid based on the query params we provided
is_valid = True
is_valid = self.is_appropriate_sub(sub)
if is_valid:
try:
if satellite:
end_date = parse(sub.get('endDate'))
else:
end_date = parse(sub['subscriptions']['endDate'])
end_date = parse(sub.get('endDate'))
except Exception:
continue
now = datetime.utcnow()
@@ -352,55 +346,44 @@ class Licenser(object):
if end_date < now:
# If the sub has a past end date, skip it
continue
try:
quantity = int(sub['quantity'])
if quantity == -1:
# effectively, unlimited
quantity = MAX_INSTANCES
except Exception:
continue
sku = sub['productId']
trial = sku.startswith('S') # i.e.,, SER/SVC
developer_license = False
support_level = ''
account_number = ''
usage = sub.get('usage', '')
usage = ''
pool_id = sub['id']
subscription_id = sub['subscriptionId']
account_number = sub['accountNumber']
if satellite:
try:
quantity = int(sub['quantity'])
except Exception:
continue
sku = sub['productId']
subscription_id = sub['subscriptionId']
sub_name = sub['productName']
support_level = sub['support_level']
account_number = sub['accountNumber']
usage = sub['usage']
else:
try:
# Determine total quantity based on capacity name
# if capacity name is Nodes, capacity quantity x subscription quantity
# if capacity name is Sockets, capacity quantity / 2 (minimum of 1) x subscription quantity
if sub['capacity']['name'] == "Nodes":
quantity = int(sub['capacity']['quantity']) * int(sub['subscriptions']['quantity'])
elif sub['capacity']['name'] == "Sockets":
quantity = max(int(sub['capacity']['quantity']) / 2, 1) * int(sub['subscriptions']['quantity'])
else:
continue
except Exception:
continue
sku = sub['sku']
sub_name = sub['name']
support_level = sub['serviceLevel']
subscription_id = sub['subscriptions']['number']
if sub.get('name') == 'RHEL Developer':
developer_license = True
if quantity == -1:
# effectively, unlimited
quantity = MAX_INSTANCES
trial = sku.startswith('S') # i.e.,, SER/SVC
for attr in sub.get('productAttributes', []):
if attr.get('name') == 'support_level':
support_level = attr.get('value')
elif attr.get('name') == 'usage':
usage = attr.get('value')
elif attr.get('name') == 'ph_product_name' and attr.get('value') == 'RHEL Developer':
developer_license = True
valid_subs.append(
ValidSub(
sku,
sub_name,
sub['productName'],
support_level,
end_date,
trial,
developer_license,
quantity,
pool_id,
satellite,
subscription_id,
account_number,
@@ -431,6 +414,7 @@ class Licenser(object):
license._attrs['satellite'] = satellite
license._attrs['valid_key'] = True
license.update(license_date=int(sub.end_date.strftime('%s')))
license.update(pool_id=sub.pool_id)
license.update(subscription_id=sub.subscription_id)
license.update(account_number=sub.account_number)
licenses.append(license._attrs.copy())

View File

@@ -422,9 +422,6 @@ DISPATCHER_DB_DOWNTIME_TOLERANCE = 40
# sqlite3 based tests will use this
DISPATCHER_MOCK_PUBLISH = False
# Debugging sockfile for the --status command
DISPATCHERD_DEBUGGING_SOCKFILE = os.path.join(BASE_DIR, 'dispatcherd.sock')
BROKER_URL = 'unix:///var/run/redis/redis.sock'
CELERYBEAT_SCHEDULE = {
'tower_scheduler': {'task': 'awx.main.tasks.system.awx_periodic_scheduler', 'schedule': timedelta(seconds=30), 'options': {'expires': 20}},
@@ -449,17 +446,6 @@ CELERYBEAT_SCHEDULE = {
},
}
DISPATCHER_SCHEDULE = {}
for options in CELERYBEAT_SCHEDULE.values():
new_options = options.copy()
task_name = options['task']
# Handle the only one exception case of the heartbeat which has a new implementation
if task_name == 'awx.main.tasks.system.cluster_node_heartbeat':
task_name = 'awx.main.tasks.system.adispatch_cluster_node_heartbeat'
new_options['task'] = task_name
new_options['schedule'] = options['schedule'].total_seconds()
DISPATCHER_SCHEDULE[task_name] = new_options
# Django Caching Configuration
DJANGO_REDIS_IGNORE_EXCEPTIONS = True
CACHES = {'default': {'BACKEND': 'awx.main.cache.AWXRedisCache', 'LOCATION': 'unix:///var/run/redis/redis.sock?db=1'}}
@@ -809,7 +795,6 @@ LOGGING = {
'social': {'handlers': ['console', 'file', 'tower_warnings'], 'level': 'DEBUG'},
'system_tracking_migrations': {'handlers': ['console', 'file', 'tower_warnings'], 'level': 'DEBUG'},
'rbac_migrations': {'handlers': ['console', 'file', 'tower_warnings'], 'level': 'DEBUG'},
'dispatcherd': {'handlers': ['dispatcher', 'console'], 'level': 'INFO'},
},
}
@@ -979,9 +964,6 @@ CLUSTER_HOST_ID = socket.gethostname()
# - 'unique_managed_hosts': Compliant = automated - deleted hosts (using /api/v2/host_metrics/)
SUBSCRIPTION_USAGE_MODEL = ''
# Default URL and query params for obtaining valid AAP subscriptions
SUBSCRIPTIONS_RHSM_URL = 'https://console.redhat.com/api/rhsm/v2/products?include=providedProducts&oids=480&status=Active'
# Host metrics cleanup - last time of the task/command run
CLEANUP_HOST_METRICS_LAST_TS = None
# Host metrics cleanup - minimal interval between two cleanups in days
@@ -1009,7 +991,7 @@ HOST_METRIC_SUMMARY_TASK_INTERVAL = 7 # days
# projects can take advantage.
METRICS_SERVICE_CALLBACK_RECEIVER = 'callback_receiver'
METRICS_SERVICE_DISPATCHER = 'dispatcherd'
METRICS_SERVICE_DISPATCHER = 'dispatcher'
METRICS_SERVICE_WEBSOCKETS = 'websockets'
METRICS_SUBSYSTEM_CONFIG = {
@@ -1092,27 +1074,8 @@ INDIRECT_HOST_QUERY_FALLBACK_GIVEUP_DAYS = 3
# Older records will be cleaned up
INDIRECT_HOST_AUDIT_RECORD_MAX_AGE_DAYS = 7
OPA_HOST = '' # The hostname used to connect to the OPA server. If empty, policy enforcement will be disabled.
OPA_PORT = 8181 # The port used to connect to the OPA server. Defaults to 8181.
OPA_SSL = False # Enable or disable the use of SSL to connect to the OPA server. Defaults to false.
OPA_AUTH_TYPE = 'None' # The authentication type that will be used to connect to the OPA server: "None", "Token", or "Certificate".
OPA_AUTH_TOKEN = '' # The token for authentication to the OPA server. Required when OPA_AUTH_TYPE is "Token". If an authorization header is defined in OPA_AUTH_CUSTOM_HEADERS, it will be overridden by OPA_AUTH_TOKEN.
OPA_AUTH_CLIENT_CERT = '' # The content of the client certificate file for mTLS authentication to the OPA server. Required when OPA_AUTH_TYPE is "Certificate".
OPA_AUTH_CLIENT_KEY = '' # The content of the client key for mTLS authentication to the OPA server. Required when OPA_AUTH_TYPE is "Certificate".
OPA_AUTH_CA_CERT = '' # The content of the CA certificate for mTLS authentication to the OPA server. Required when OPA_AUTH_TYPE is "Certificate".
OPA_AUTH_CUSTOM_HEADERS = {} # Optional custom headers included in requests to the OPA server. Defaults to empty dictionary ({}).
OPA_REQUEST_TIMEOUT = 1.5 # The number of seconds after which the connection to the OPA server will time out. Defaults to 1.5 seconds.
OPA_REQUEST_RETRIES = 2 # The number of retry attempts for connecting to the OPA server. Default is 2.
# feature flags
FLAG_SOURCES = ('flags.sources.SettingsFlagsSource',)
FLAGS = {
'FEATURE_INDIRECT_NODE_COUNTING_ENABLED': [{'condition': 'boolean', 'value': False}],
'FEATURE_DISPATCHERD_ENABLED': [{'condition': 'boolean', 'value': False}],
}
FLAGS = {'FEATURE_INDIRECT_NODE_COUNTING_ENABLED': [{'condition': 'boolean', 'value': False}]}
# Dispatcher worker lifetime. If set to None, workers will never be retired
# based on age. Note workers will finish their last task before retiring if
# they are busy when they reach retirement age.
WORKER_MAX_LIFETIME_SECONDS = 14400 # seconds
FLAG_SOURCES = ('flags.sources.SettingsFlagsSource',)

Some files were not shown because too many files have changed in this diff Show More