Compare commits

..

2 Commits

Author SHA1 Message Date
Peter Braun
ed1b5c5519 exclude password fields from skipping if not defined 2026-02-17 12:19:43 +01:00
Peter Braun
d2e51c4124 do not add optional survey fields with empty strings that are not backed by extra_vars 2026-02-17 10:13:33 +01:00
163 changed files with 1323 additions and 10816 deletions

View File

@@ -24,7 +24,7 @@ in as the first entry for your PR title.
##### STEPS TO REPRODUCE AND EXTRA INFO
##### ADDITIONAL INFORMATION
<!---
Include additional information to help people understand the change here.
For bugs that don't have a linked bug report, a step-by-step reproduction

View File

@@ -1,55 +0,0 @@
---
name: Repo Owns Branch
# Reusable workflow that determines whether the current repository
# owns the current branch for push operations.
#
# Ownership rules:
# - ansible/awx owns: devel, feature_*
# - ansible/tower owns: stable-*, release_*
# - workflow_dispatch is always allowed
#
# All other repo/branch combinations are skipped.
on:
workflow_call:
outputs:
should_run:
description: Whether this repo owns the current branch
value: ${{ jobs.check.outputs.should_run }}
jobs:
check:
runs-on: ubuntu-latest
outputs:
should_run: ${{ steps.check.outputs.should_run }}
steps:
- name: Check branch ownership
id: check
run: |
REPO="${{ github.repository }}"
BRANCH="${{ github.ref_name }}"
EVENT="${{ github.event_name }}"
if [[ "$EVENT" == "workflow_dispatch" ]]; then
echo "should_run=true" >> $GITHUB_OUTPUT
echo "Manual trigger — allowed"
exit 0
fi
# ansible/awx owns devel and feature_* branches
if [[ "$REPO" == "ansible/awx" ]] && [[ "$BRANCH" == "devel" || "$BRANCH" == feature_* ]]; then
echo "should_run=true" >> $GITHUB_OUTPUT
echo "Repository '$REPO' owns branch '$BRANCH'"
exit 0
fi
# ansible/tower owns stable-* and release_* branches
if [[ "$REPO" == "ansible/tower" ]] && [[ "$BRANCH" == stable-* || "$BRANCH" == release_* ]]; then
echo "should_run=true" >> $GITHUB_OUTPUT
echo "Repository '$REPO' owns branch '$BRANCH'"
exit 0
fi
echo "should_run=false" >> $GITHUB_OUTPUT
echo "Repository '$REPO' does not own branch '$BRANCH' — skipping"

View File

@@ -4,46 +4,14 @@ env:
LC_ALL: "C.UTF-8" # prevent ERROR: Ansible could not initialize the preferred locale: unsupported locale setting
CI_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DEV_DOCKER_OWNER: ${{ github.repository_owner }}
COMPOSE_TAG: ${{ github.base_ref || github.ref_name || 'devel' }}
COMPOSE_TAG: ${{ github.base_ref || 'devel' }}
UPSTREAM_REPOSITORY_ID: 91594105
on:
pull_request:
push:
branches:
- devel # needed to publish code coverage post-merge
schedule:
- cron: '0 12,18 * * 1-5'
workflow_dispatch: {}
jobs:
trigger-release-branches:
name: "Dispatch CI to release branches"
if: github.event_name == 'schedule'
runs-on: ubuntu-latest
permissions:
actions: write
steps:
- name: Trigger CI on release_4.6
id: dispatch_release_46
continue-on-error: true
run: gh workflow run ci.yml --ref release_4.6
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GH_REPO: ${{ github.repository }}
- name: Trigger CI on stable-2.6
id: dispatch_stable_26
continue-on-error: true
run: gh workflow run ci.yml --ref stable-2.6
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GH_REPO: ${{ github.repository }}
- name: Check dispatch results
if: steps.dispatch_release_46.outcome == 'failure' || steps.dispatch_stable_26.outcome == 'failure'
run: |
echo "One or more dispatches failed:"
echo " release_4.6: ${{ steps.dispatch_release_46.outcome }}"
echo " stable-2.6: ${{ steps.dispatch_stable_26.outcome }}"
exit 1
common-tests:
name: ${{ matrix.tests.name }}
runs-on: ubuntu-latest
@@ -94,11 +62,7 @@ jobs:
run: |
if [ -f "reports/coverage.xml" ]; then
sed -i '2i<!-- PR ${{ github.event.pull_request.number }} -->' reports/coverage.xml
echo "Injected PR number ${{ github.event.pull_request.number }} into reports/coverage.xml"
fi
if [ -f "awxkit/coverage.xml" ]; then
sed -i '2i<!-- PR ${{ github.event.pull_request.number }} -->' awxkit/coverage.xml
echo "Injected PR number ${{ github.event.pull_request.number }} into awxkit/coverage.xml"
echo "Injected PR number ${{ github.event.pull_request.number }} into coverage.xml"
fi
- name: Upload test coverage to Codecov
@@ -145,9 +109,7 @@ jobs:
uses: actions/upload-artifact@v4
with:
name: ${{ matrix.tests.name }}-artifacts
path: |
reports/coverage.xml
awxkit/coverage.xml
path: reports/coverage.xml
retention-days: 5
- name: >-
@@ -160,7 +122,7 @@ jobs:
&& github.event_name == 'push'
&& env.UPSTREAM_REPOSITORY_ID == github.repository_id
&& github.ref_name == github.event.repository.default_branch
uses: ansible/gh-action-record-test-results@3784db66a1b7fb3809999a7251c8a7203a7ffbe8
uses: ansible/gh-action-record-test-results@cd5956ead39ec66351d0779470c8cff9638dd2b8
with:
aggregation-server-url: ${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_URL }}
http-auth-password: >-
@@ -334,7 +296,7 @@ jobs:
&& github.event_name == 'push'
&& env.UPSTREAM_REPOSITORY_ID == github.repository_id
&& github.ref_name == github.event.repository.default_branch
uses: ansible/gh-action-record-test-results@3784db66a1b7fb3809999a7251c8a7203a7ffbe8
uses: ansible/gh-action-record-test-results@cd5956ead39ec66351d0779470c8cff9638dd2b8
with:
aggregation-server-url: ${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_URL }}
http-auth-password: >-

View File

@@ -12,12 +12,7 @@ on:
- feature_*
- stable-*
jobs:
check-ownership:
uses: ./.github/workflows/_repo-owns-branch.yml
push-development-images:
needs: check-ownership
if: needs.check-ownership.outputs.should_run == 'true'
runs-on: ubuntu-latest
timeout-minutes: 120
permissions:
@@ -35,6 +30,12 @@ jobs:
make-target: awx-kube-buildx
steps:
- name: Skipping build of awx image for non-awx repository
run: |
echo "Skipping build of awx image for non-awx repository"
exit 0
if: matrix.build-targets.image-name == 'awx' && !endsWith(github.repository, '/awx')
- uses: actions/checkout@v4
with:
show-progress: false

View File

@@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 20
permissions:
packages: read
packages: write
contents: read
steps:
- name: Check for each of the lines

View File

@@ -16,16 +16,9 @@ on:
push:
branches:
- devel
- 'stable-2.[6-9]'
- 'stable-2.[1-9][0-9]'
workflow_dispatch: # Allow manual triggering for testing
jobs:
check-ownership:
uses: ./.github/workflows/_repo-owns-branch.yml
sync-openapi-spec:
needs: check-ownership
if: needs.check-ownership.outputs.should_run == 'true'
name: Sync OpenAPI spec to central repo
runs-on: ubuntu-latest
permissions:

View File

@@ -13,12 +13,7 @@ on:
- feature_**
- stable-**
jobs:
check-ownership:
uses: ./.github/workflows/_repo-owns-branch.yml
push:
needs: check-ownership
if: needs.check-ownership.outputs.should_run == 'true'
runs-on: ubuntu-latest
timeout-minutes: 60
permissions:

View File

@@ -1,65 +0,0 @@
---
apiVersion: tekton.dev/v1
kind: PipelineRun
metadata:
name: awx-atf-tests-pull-request
annotations:
build.appstudio.openshift.io/repo: https://github.com/{{repo_owner}}/{{repo_name}}?rev={{revision}}
build.appstudio.redhat.com/commit_sha: '{{revision}}'
build.appstudio.redhat.com/pull_request_number: '{{pull_request_number}}'
build.appstudio.redhat.com/target_branch: '{{target_branch}}'
pipelinesascode.tekton.dev/cancel-in-progress: 'true'
pipelinesascode.tekton.dev/max-keep-runs: "3"
pipelinesascode.tekton.dev/on-comment: "^/run-atf-tests$"
pipelinesascode.tekton.dev/target-namespace: ansible-ci-tenant
labels:
appstudio.openshift.io/application: '{{repo_owner}}'
appstudio.openshift.io/component: '{{repo_owner}}-{{repo_name}}'
pipelines.appstudio.openshift.io/type: build
spec:
timeouts:
pipeline: "8h"
tasks: "7h"
finally: "1h"
pipelineRef:
resolver: bundles
params:
- name: name
value: aap-api-tests
- name: bundle
value: quay.io/aap-ci/tekton-catalog/pipeline/test/aap-api-tests:0.1@sha256:54d9e941748bae94b2154b3b253a985e628751dfa4508a138d9b05f74a3c1ddf
- name: kind
value: pipeline
- name: secret
value: quay-aap-ci-viewer
taskRunTemplate:
serviceAccountName: konflux-integration-runner
params:
- name: git-url
value: "{{source_url}}"
- name: pipeline-github-org
value: "{{repo_owner}}"
- name: pipeline-github-repo
value: "{{repo_name}}"
- name: pipeline-github-target-branch
value: '{{target_branch}}'
- name: pipeline-github-pr-revision
value: "{{revision}}"
- name: pipeline-github-pr-number
value: "{{pull_request_number}}"
- name: aap-dev-component-source-name
value: "controller"
- name: pytest-number-of-parallel-processes
value: "6"
workspaces:
- name: workspace
volumeClaimTemplate:
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: 1Gi

View File

@@ -103,12 +103,6 @@ When necessary, remove any AWX containers and images by running the following:
### Pre commit hooks
Install the pre-commit hook before contributing:
```
make pre-commit
```
When you attempt to perform a `git commit` there will be a pre-commit hook that gets run before the commit is allowed to your local repository. For example, python's [black](https://pypi.org/project/black/) will be run to test the formatting of any python files.
While you can use environment variables to skip the pre-commit hooks GitHub will run similar tests and prevent merging of PRs if the tests do not pass.

View File

@@ -10,7 +10,6 @@ KIND_BIN ?= $(shell which kind)
CHROMIUM_BIN=/tmp/chrome-linux/chrome
GIT_REPO_NAME ?= $(shell basename `git rev-parse --show-toplevel`)
GIT_BRANCH ?= $(shell git rev-parse --abbrev-ref HEAD)
GIT_IS_WORKTREE := $(shell test -f .git && echo yes)
MANAGEMENT_COMMAND ?= awx-manage
VERSION ?= $(shell $(PYTHON) tools/scripts/scm_version.py 2> /dev/null)
@@ -107,15 +106,6 @@ else
DOCKER_KUBE_CACHE_FLAG=$(DOCKER_CACHE)
endif
# AWX TUI variables
AWX_HOST ?= https://localhost:8043
AWX_USER ?= admin
AWX_PASSWORD ?= $$(awk -F"'" '/^admin_password:/{print $$2}' tools/docker-compose/_sources/secrets/admin_password.yml 2>/dev/null || echo "admin")
AWX_VERIFY_SSL ?= false
# For git worktree to find the referenced git dir
GIT_COMMON_DIR := $(shell git rev-parse --git-common-dir 2>/dev/null || echo .git)
.PHONY: awx-link clean clean-tmp clean-venv requirements requirements_dev \
update_requirements upgrade_requirements update_requirements_dev \
docker_update_requirements docker_upgrade_requirements docker_update_requirements_dev \
@@ -123,7 +113,7 @@ GIT_COMMON_DIR := $(shell git rev-parse --git-common-dir 2>/dev/null || echo .gi
receiver test test_unit test_coverage coverage_html \
sdist \
VERSION PYTHON_VERSION docker-compose-sources \
pre-commit
.git/hooks/pre-commit
clean-tmp:
rm -rf tmp/
@@ -352,10 +342,11 @@ black: reports
@command -v black >/dev/null 2>&1 || { echo "could not find black on your PATH, you may need to \`pip install black\`, or set AWX_IGNORE_BLACK=1" && exit 1; }
@(set -o pipefail && $@ $(BLACK_ARGS) awx awxkit awx_collection | tee reports/$@.report)
$(GIT_COMMON_DIR)/hooks/pre-commit:
ln -sf ../../pre-commit.sh $(GIT_COMMON_DIR)/hooks/pre-commit
pre-commit: $(GIT_COMMON_DIR)/hooks/pre-commit
.git/hooks/pre-commit:
@echo "if [ -x pre-commit.sh ]; then" > .git/hooks/pre-commit
@echo " ./pre-commit.sh;" >> .git/hooks/pre-commit
@echo "fi" >> .git/hooks/pre-commit
@chmod +x .git/hooks/pre-commit
genschema: awx-link reports
@if [ "$(VENV_BASE)" ]; then \
@@ -530,7 +521,7 @@ ifneq ($(ADMIN_PASSWORD),)
EXTRA_SOURCES_ANSIBLE_OPTS := -e admin_password=$(ADMIN_PASSWORD) $(EXTRA_SOURCES_ANSIBLE_OPTS)
endif
docker-compose-sources:
docker-compose-sources: .git/hooks/pre-commit
@if [ $(MINIKUBE_CONTAINER_GROUP) = true ]; then\
$(ANSIBLE_PLAYBOOK) -i tools/docker-compose/inventory -e minikube_setup=$(MINIKUBE_SETUP) tools/docker-compose-minikube/deploy.yml; \
fi;
@@ -562,7 +553,7 @@ docker-compose: awx/projects docker-compose-sources
$(MAKE) docker-compose-up
docker-compose-up:
$(if $(GIT_IS_WORKTREE),SETUPTOOLS_SCM_PRETEND_VERSION="$(VERSION)") $(DOCKER_COMPOSE) -f tools/docker-compose/_sources/docker-compose.yml $(COMPOSE_OPTS) up $(COMPOSE_UP_OPTS) --remove-orphans
$(DOCKER_COMPOSE) -f tools/docker-compose/_sources/docker-compose.yml $(COMPOSE_OPTS) up $(COMPOSE_UP_OPTS) --remove-orphans
docker-compose-down:
$(DOCKER_COMPOSE) -f tools/docker-compose/_sources/docker-compose.yml $(COMPOSE_OPTS) down --remove-orphans
@@ -580,20 +571,6 @@ docker-compose-runtest: awx/projects docker-compose-sources
docker-compose-build-schema: awx/projects docker-compose-sources
$(DOCKER_COMPOSE) -f tools/docker-compose/_sources/docker-compose.yml run --rm --service-ports --no-deps awx_1 make genschema
awx-tui:
@if ! command -v awx-tui > /dev/null 2>&1; then \
$(PYTHON) -m pip install awx-tui; \
fi
@if [ -f "$(HOME)/.config/awx-tui/config.yaml" ]; then \
$(PYTHON) -m awx_tui.main; \
else \
AWX_HOST=$(AWX_HOST) \
AWX_USER=$(AWX_USER) \
AWX_PASSWORD=$(AWX_PASSWORD) \
AWX_VERIFY_SSL=$(AWX_VERIFY_SSL) \
$(PYTHON) -m awx_tui.main --host $(AWX_HOST); \
fi
SCHEMA_DIFF_BASE_FOLDER ?= awx
SCHEMA_DIFF_BASE_BRANCH ?= devel
detect-schema-change: genschema
@@ -604,7 +581,7 @@ detect-schema-change: genschema
validate-openapi-schema: genschema
@echo "Validating OpenAPI schema from schema.json..."
@python3 -c "from openapi_spec_validator import validate; import json; spec = json.load(open('schema.json')); validate(spec); print('✓ Schema is valid')"
@python3 -c "from openapi_spec_validator import validate; import json; spec = json.load(open('schema.json')); validate(spec); print('✓ OpenAPI Schema is valid!')"
docker-compose-clean: awx/projects
$(DOCKER_COMPOSE) -f tools/docker-compose/_sources/docker-compose.yml rm -sf

View File

@@ -131,14 +131,8 @@ class LoggedLoginView(auth_views.LoginView):
class LoggedLogoutView(auth_views.LogoutView):
# Override http_method_names to allow GET requests (Django 5.2+ defaults to POST only)
http_method_names = ["get", "post", "options"]
success_url_allowed_hosts = set(settings.LOGOUT_ALLOWED_HOSTS.split(",")) if settings.LOGOUT_ALLOWED_HOSTS else set()
def get(self, request, *args, **kwargs):
"""Handle GET requests for logout (for backward compatibility)."""
return self.post(request, *args, **kwargs)
def dispatch(self, request, *args, **kwargs):
if is_proxied_request():
# 1) We intentionally don't obey ?next= here, just always redirect to platform login
@@ -272,10 +266,7 @@ class APIView(views.APIView):
response = self.handle_exception(self.__init_request_error__)
if response.status_code == 401:
if response.data and 'detail' in response.data:
if getattr(settings, 'RESOURCE_SERVER__URL', None):
response.data['detail'] += _(' Direct access is not allowed, authenticate via the platform gateway.')
else:
response.data['detail'] += _(' To establish a login session, visit') + ' /api/login/.'
response.data['detail'] += _(' To establish a login session, visit') + ' /api/login/.'
logger.info(status_msg)
else:
logger.warning(status_msg)

View File

@@ -122,6 +122,7 @@ from awx.main.scheduler.task_manager_models import TaskManagerModels
from awx.main.redact import UriCleaner, REPLACE_STR
from awx.main.signals import update_inventory_computed_fields
from awx.main.validators import vars_validate_or_raise
from awx.api.versioning import reverse
@@ -174,8 +175,8 @@ SUMMARIZABLE_FK_FIELDS = {
'workflow_approval': DEFAULT_SUMMARY_FIELDS + ('timeout',),
'schedule': DEFAULT_SUMMARY_FIELDS + ('next_run',),
'unified_job_template': DEFAULT_SUMMARY_FIELDS + ('unified_job_type',),
# last_job and last_job_host_summary are derived from JobHostSummary in HostSerializer,
# not from the stale FK fields on Host.
'last_job': DEFAULT_SUMMARY_FIELDS + ('finished', 'status', 'failed', 'license_error', 'canceled_on'),
'last_job_host_summary': DEFAULT_SUMMARY_FIELDS + ('failed',),
'last_update': DEFAULT_SUMMARY_FIELDS + ('status', 'failed', 'license_error'),
'current_update': DEFAULT_SUMMARY_FIELDS + ('status', 'failed', 'license_error'),
'current_job': DEFAULT_SUMMARY_FIELDS + ('status', 'failed', 'license_error'),
@@ -961,32 +962,14 @@ class UnifiedJobSerializer(BaseSerializer):
class UnifiedJobListSerializer(UnifiedJobSerializer):
# these fields can be included optionally in the response
OPTIONAL_INCLUDE_FIELDS = frozenset({'artifacts', 'extra_vars'})
# these fields are stripped from the response
_STRIPPED_FIELDS = frozenset({'job_args', 'job_cwd', 'job_env', 'result_traceback', 'event_processing_finished', 'artifacts', 'extra_vars'})
class Meta:
fields = ('*', '-job_args', '-job_cwd', '-job_env', '-result_traceback', '-event_processing_finished', '-artifacts', '-extra_vars')
# processes the include query param if present
def _requested_includes(self):
request = self.context.get('request')
if request is None:
return frozenset()
raw = request.query_params.get('include', '')
requested = {name.strip() for name in raw.split(',') if name.strip()}
# only allow the fields listed in OPTIONAL_INCLUDE_FIELDS
return frozenset(requested) & self.OPTIONAL_INCLUDE_FIELDS
fields = ('*', '-job_args', '-job_cwd', '-job_env', '-result_traceback', '-event_processing_finished', '-artifacts')
def get_field_names(self, declared_fields, info):
field_names = super(UnifiedJobListSerializer, self).get_field_names(declared_fields, info)
# Meta multiple inheritance and -field_name options don't seem to be
# taking effect above, so remove the undesired fields here.
strip = self._STRIPPED_FIELDS - self._requested_includes()
return tuple(x for x in field_names if x not in strip)
return tuple(x for x in field_names if x not in ('job_args', 'job_cwd', 'job_env', 'result_traceback', 'event_processing_finished', 'artifacts'))
def get_types(self):
if type(self) is UnifiedJobListSerializer:
@@ -1039,7 +1022,7 @@ class UnifiedJobStdoutSerializer(UnifiedJobSerializer):
class UserSerializer(BaseSerializer):
password = serializers.CharField(required=False, default='', allow_blank=True, help_text=_('Field used to change the password.'))
password = serializers.CharField(required=False, default='', help_text=_('Field used to change the password.'))
is_system_auditor = serializers.BooleanField(default=False)
show_capabilities = ['edit', 'delete']
@@ -1855,35 +1838,19 @@ class HostSerializer(BaseSerializerWithVariables):
res['ansible_facts'] = self.reverse('api:host_ansible_facts_detail', kwargs={'pk': obj.instance_id})
if obj.inventory:
res['inventory'] = self.reverse('api:inventory_detail', kwargs={'pk': obj.inventory.pk})
last_summary = obj.latest_summary
if last_summary:
res['last_job_host_summary'] = self.reverse('api:job_host_summary_detail', kwargs={'pk': last_summary.pk})
if last_summary.job_id:
res['last_job'] = self.reverse('api:job_detail', kwargs={'pk': last_summary.job_id})
if obj.last_job:
res['last_job'] = self.reverse('api:job_detail', kwargs={'pk': obj.last_job.pk})
if obj.last_job_host_summary:
res['last_job_host_summary'] = self.reverse('api:job_host_summary_detail', kwargs={'pk': obj.last_job_host_summary.pk})
return res
def get_summary_fields(self, obj):
d = super(HostSerializer, self).get_summary_fields(obj)
last_summary = obj.latest_summary
if last_summary:
d['last_job_host_summary'] = OrderedDict()
d['last_job_host_summary']['id'] = last_summary.id
d['last_job_host_summary']['failed'] = last_summary.failed
try:
last_job = last_summary.job
d['last_job'] = OrderedDict()
for field in DEFAULT_SUMMARY_FIELDS + ('finished', 'status', 'failed', 'canceled_on'):
fval = getattr(last_job, field, None)
if fval is not None:
d['last_job'][field] = fval
if last_job.job_template:
d['last_job']['job_template_id'] = last_job.job_template.id
d['last_job']['job_template_name'] = last_job.job_template.name
except ObjectDoesNotExist:
pass
else:
d.pop('last_job', None)
d.pop('last_job_host_summary', None)
try:
d['last_job']['job_template_id'] = obj.last_job.job_template.id
d['last_job']['job_template_name'] = obj.last_job.job_template.name
except (KeyError, AttributeError):
pass
if has_model_field_prefetched(obj, 'groups'):
group_list = sorted([{'id': g.id, 'name': g.name} for g in obj.groups.all()], key=lambda x: x['id'])[:5]
else:
@@ -1958,16 +1925,14 @@ class HostSerializer(BaseSerializerWithVariables):
return ret
if 'inventory' in ret and not obj.inventory:
ret['inventory'] = None
last_summary = obj.latest_summary
if 'last_job' in ret:
ret['last_job'] = last_summary.job_id if last_summary else None
if 'last_job_host_summary' in ret:
ret['last_job_host_summary'] = last_summary.pk if last_summary else None
if 'last_job' in ret and not obj.last_job:
ret['last_job'] = None
if 'last_job_host_summary' in ret and not obj.last_job_host_summary:
ret['last_job_host_summary'] = None
return ret
def get_has_active_failures(self, obj):
last_summary = obj.latest_summary
return bool(last_summary and last_summary.failed)
return bool(obj.last_job_host_summary and obj.last_job_host_summary.failed)
def get_has_inventory_sources(self, obj):
return obj.inventory_sources.exists()
@@ -2114,17 +2079,9 @@ class BulkHostCreateSerializer(serializers.Serializer):
if request and not request.user.is_superuser:
if request.user not in inv.admin_role:
raise serializers.ValidationError(_(f'Inventory with id {inv.id} not found or lack permissions to add hosts.'))
# Performance optimization (AAP-67978): Instead of loading ALL host names from
# the inventory, only check if the specific new names already exist in the database.
current_hostnames = set(inv.hosts.values_list('name', flat=True))
new_names = [host['name'] for host in attrs['hosts']]
new_name_counts = Counter(new_names)
duplicates_in_new = [name for name, count in new_name_counts.items() if count > 1]
unique_new_names = list(new_name_counts.keys())
existing_duplicates = list(Host.objects.filter(inventory=inv, name__in=unique_new_names).values_list('name', flat=True))
duplicate_new_names = list(set(duplicates_in_new + existing_duplicates))
duplicate_new_names = [n for n in new_names if n in current_hostnames or new_names.count(n) > 1]
if duplicate_new_names:
raise serializers.ValidationError(_(f'Hostnames must be unique in an inventory. Duplicates found: {duplicate_new_names}'))
@@ -2975,19 +2932,6 @@ class CredentialTypeSerializer(BaseSerializer):
field['label'] = _(field['label'])
if 'help_text' in field:
field['help_text'] = _(field['help_text'])
# Deep copy inputs to avoid modifying the original model data
inputs = value.get('inputs')
if not isinstance(inputs, dict):
inputs = {}
value['inputs'] = copy.deepcopy(inputs)
fields = value['inputs'].get('fields', [])
if not isinstance(fields, list):
fields = []
# Normalize fields and filter out internal fields
value['inputs']['fields'] = [f for f in fields if not f.get('internal')]
return value
def filter_field_metadata(self, fields, method):
@@ -4178,28 +4122,9 @@ class LaunchConfigurationBaseSerializer(BaseSerializer):
attrs['extra_data'][key] = db_extra_data[key]
# Build unsaved version of this config, use it to detect prompts errors
# Capture keys before _build_mock_obj pops pseudo-fields from attrs
incoming_attr_keys = set(attrs.keys())
mock_obj = self._build_mock_obj(attrs)
ask_mapping_keys = set(ujt.get_ask_mapping().keys())
requested_prompt_fields = incoming_attr_keys & ask_mapping_keys
if 'extra_data' in incoming_attr_keys:
requested_prompt_fields.add('extra_vars')
requested_prompt_fields.add('survey_passwords')
# prompts_dict() pulls persisted M2M state (labels, credentials,
# instance_groups) via the instance pk. Only re-validate the full prompt
# state when the caller is switching the underlying template; otherwise
# restrict validation to the fields the request explicitly provided.
if 'unified_job_template' in attrs:
prompts_to_validate = mock_obj.prompts_dict()
elif requested_prompt_fields:
prompts_to_validate = {k: v for k, v in mock_obj.prompts_dict().items() if k in requested_prompt_fields}
else:
prompts_to_validate = None
if prompts_to_validate is not None:
accepted, rejected, errors = ujt._accept_or_ignore_job_kwargs(_exclude_errors=self.exclude_errors, **prompts_to_validate)
if set(list(ujt.get_ask_mapping().keys()) + ['extra_data']) & set(attrs.keys()):
accepted, rejected, errors = ujt._accept_or_ignore_job_kwargs(_exclude_errors=self.exclude_errors, **mock_obj.prompts_dict())
else:
# Only perform validation of prompts if prompts fields are provided
errors = {}

View File

@@ -1,4 +1,4 @@
---
collections:
- name: ansible.receptor
version: 2.0.8
version: 2.0.6

View File

@@ -14,14 +14,13 @@ import sys
import time
from base64 import b64encode
from collections import OrderedDict
from jwt import decode as _jwt_decode
from urllib3.exceptions import ConnectTimeoutError
# Django
from django.conf import settings
from django.core.exceptions import FieldError, ObjectDoesNotExist
from django.db.models import Q, Sum, Count, Subquery, OuterRef
from django.db.models import Q, Sum, Count
from django.db import IntegrityError, ProgrammingError, transaction, connection
from django.db.models.fields.related import ManyToManyField, ForeignKey
from django.db.models.functions import Trunc
@@ -56,16 +55,10 @@ from wsgiref.util import FileWrapper
from drf_spectacular.utils import extend_schema_view, extend_schema
# django-ansible-base
from ansible_base.lib.utils.requests import get_remote_hosts
from ansible_base.rbac.models import RoleEvaluation
from ansible_base.lib.utils.schema import extend_schema_if_available
from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope
# flags
from flags.state import flag_enabled
# AWX
from awx.main.utils.workload_identity import retrieve_workload_identity_jwt_with_claims
from awx.main.tasks.system import send_notifications, update_inventory_computed_fields
from awx.main.access import get_user_queryset
from awx.api.generics import (
@@ -104,6 +97,7 @@ from awx.main.utils import (
from awx.main.utils.encryption import encrypt_value
from awx.main.utils.filters import SmartFilter
from awx.main.utils.plugins import compute_cloud_inventory_sources
from awx.main.utils.proxy import get_first_remote_host_from_headers
from awx.main.utils.common import memoize
from awx.main.redact import UriCleaner
from awx.api.permissions import (
@@ -127,7 +121,6 @@ from awx.api.views.mixin import (
RelatedJobsPreventDeleteMixin,
UnifiedJobDeletionMixin,
NoTruncateMixin,
UnifiedJobIncludeMixin,
)
from awx.api.pagination import UnifiedJobEventPagination
from awx.main.utils import set_environ
@@ -210,12 +203,11 @@ class DashboardView(APIView):
groups_inventory_failed = models.Group.objects.filter(inventory_sources__last_job_failed=True).count()
data['groups'] = {'url': reverse('api:group_list', request=request), 'total': user_groups.count(), 'inventory_failed': groups_inventory_failed}
user_hosts = get_user_queryset(request.user, models.Host).exclude(inventory__kind='constructed')
latest_summary_failed = Subquery(models.JobHostSummary.objects.filter(host_id=OuterRef('pk')).order_by('-id').values('failed')[:1])
user_hosts_failed = user_hosts.annotate(_latest_failed=latest_summary_failed).filter(_latest_failed=True)
user_hosts = get_user_queryset(request.user, models.Host)
user_hosts_failed = user_hosts.filter(last_job_host_summary__failed=True)
data['hosts'] = {
'url': reverse('api:host_list', request=request),
'failures_url': reverse('api:host_list', request=request) + "?last_job_host_summary__failed=True",
'total': user_hosts.count(),
'failed': user_hosts_failed.count(),
}
@@ -802,11 +794,22 @@ class TeamRolesList(SubListAttachDetachAPIView):
data = dict(msg=_("You cannot grant system-level permissions to a team."))
return Response(data, status=status.HTTP_400_BAD_REQUEST)
if not request.data.get('disassociate'):
team = get_object_or_404(models.Team, pk=self.kwargs['pk'])
content_object = role.content_object
if hasattr(content_object, 'validate_role_assignment'):
content_object.validate_role_assignment(team, role_definition=None, requesting_user=request.user)
team = get_object_or_404(models.Team, pk=self.kwargs['pk'])
credential_content_type = ContentType.objects.get_for_model(models.Credential)
if role.content_type == credential_content_type:
if not role.content_object.organization:
data = dict(
msg=_("You cannot grant access to a credential that is not assigned to an organization (private credentials cannot be assigned to teams)")
)
return Response(data, status=status.HTTP_400_BAD_REQUEST)
elif role.content_object.organization.id != team.organization.id:
if not request.user.is_superuser:
data = dict(
msg=_(
"You cannot grant a team access to a credential in a different organization. Only superusers can grant cross-organization credential access to teams"
)
)
return Response(data, status=status.HTTP_400_BAD_REQUEST)
return super(TeamRolesList, self).post(request, *args, **kwargs)
@@ -1265,12 +1268,19 @@ class UserRolesList(SubListAttachDetachAPIView):
if not sub_id:
return super(UserRolesList, self).post(request)
if not request.data.get('disassociate'):
role = get_object_or_400(models.Role, pk=sub_id)
user = get_object_or_400(models.User, pk=self.kwargs['pk'])
content_object = role.content_object
if hasattr(content_object, 'validate_role_assignment'):
content_object.validate_role_assignment(user, role_definition=None, requesting_user=request.user)
user = get_object_or_400(models.User, pk=self.kwargs['pk'])
role = get_object_or_400(models.Role, pk=sub_id)
content_types = ContentType.objects.get_for_models(models.Organization, models.Team, models.Credential) # dict of {model: content_type}
credential_content_type = content_types[models.Credential]
if role.content_type == credential_content_type:
if 'disassociate' not in request.data and role.content_object.organization and user not in role.content_object.organization.member_role:
data = dict(msg=_("You cannot grant credential access to a user not in the credentials' organization"))
return Response(data, status=status.HTTP_400_BAD_REQUEST)
if not role.content_object.organization and not request.user.is_superuser:
data = dict(msg=_("You cannot grant private credential access to another user"))
return Response(data, status=status.HTTP_400_BAD_REQUEST)
return super(UserRolesList, self).post(request, *args, **kwargs)
@@ -1585,175 +1595,7 @@ class CredentialCopy(CopyAPIView):
resource_purpose = 'copy of a credential'
class OIDCCredentialTestMixin:
"""
Mixin to add OIDC workload identity token support to credential test endpoints.
This mixin provides methods to handle OIDC-enabled external credentials that use
workload identity tokens for authentication.
"""
@staticmethod
def _get_workload_identity_token(job_template: models.JobTemplate, audience: str) -> str:
"""Generate a workload identity token for a job template.
Args:
job_template: The JobTemplate instance to generate claims for
audience: The JWT audience claim value
Returns:
str: The generated JWT token
"""
claims = {
AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME: job_template.organization.name,
AutomationControllerJobScope.CLAIM_ORGANIZATION_ID: job_template.organization.id,
AutomationControllerJobScope.CLAIM_PROJECT_NAME: job_template.project.name,
AutomationControllerJobScope.CLAIM_PROJECT_ID: job_template.project.id,
AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_NAME: job_template.name,
AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_ID: job_template.id,
AutomationControllerJobScope.CLAIM_PLAYBOOK_NAME: job_template.playbook,
}
return retrieve_workload_identity_jwt_with_claims(
claims=claims,
audience=audience,
scope=AutomationControllerJobScope.name,
)
@staticmethod
def _decode_jwt_payload_for_display(jwt_token):
"""Decode JWT payload for display purposes only (signature not verified).
This is safe because the JWT was just created by AWX and is only decoded
to show the user what claims are being sent to the external system.
The external system will perform proper signature verification.
Args:
jwt_token: The JWT token to decode
Returns:
dict: The decoded JWT payload
"""
return _jwt_decode(jwt_token, algorithms=["RS256"], options={"verify_signature": False}) # NOSONAR python:S5659
def _has_workload_identity_token(self, credential_type_inputs):
"""Check if credential type has an internal workload_identity_token field.
Args:
credential_type_inputs: The inputs dict from a credential type
Returns:
bool: True if the credential type has a workload_identity_token field marked as internal
"""
fields = credential_type_inputs.get('fields', []) if isinstance(credential_type_inputs, dict) else []
return any(field.get('internal') and field.get('id') == 'workload_identity_token' for field in fields)
def _validate_and_get_job_template(self, job_template_id):
"""Validate job template ID and return the JobTemplate instance.
Args:
job_template_id: The job template ID from metadata
Returns:
JobTemplate instance
Raises:
ParseError: If job_template_id is invalid or not found
"""
if job_template_id is None:
raise ParseError(_('Job template ID is required.'))
try:
return models.JobTemplate.objects.get(id=int(job_template_id))
except ValueError:
raise ParseError(_('Job template ID must be an integer.'))
except models.JobTemplate.DoesNotExist:
raise ParseError(_('Job template with ID %(id)s does not exist.') % {'id': job_template_id})
def _handle_oidc_credential_test(self, backend_kwargs):
"""
Handle OIDC workload identity token generation for external credential test endpoints.
This method should only be called when FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED is enabled
and the credential type has a workload_identity_token field.
Args:
backend_kwargs: The kwargs dict to pass to the backend (will be modified in place)
Returns:
dict: Response body containing details with the sent JWT payload
Raises:
PermissionDenied: If user lacks access to the job template (re-raised for 403 response)
All other exceptions are caught and converted to 400 responses with error details.
Modifies backend_kwargs in place to add workload_identity_token.
"""
# Validate job template
job_template_id = backend_kwargs.pop('job_template_id', None)
job_template = self._validate_and_get_job_template(job_template_id)
# Check user access
if not self.request.user.can_access(models.JobTemplate, 'start', job_template):
raise PermissionDenied(_('You do not have access to job template with id: %(id)s.') % {'id': job_template.id})
# Generate workload identity token
jwt_token = self._get_workload_identity_token(job_template, backend_kwargs.get('url'))
backend_kwargs['workload_identity_token'] = jwt_token
return {'details': {'sent_jwt_payload': self._decode_jwt_payload_for_display(jwt_token)}}
def _call_backend_with_error_handling(self, plugin, backend_kwargs, response_body):
"""Call credential backend and handle errors."""
try:
with set_environ(**settings.AWX_TASK_ENV):
plugin.backend(**backend_kwargs)
return Response(response_body, status=status.HTTP_202_ACCEPTED)
except requests.exceptions.HTTPError as exc:
message = self._extract_http_error_message(exc)
self._add_error_to_response(response_body, message)
return Response(response_body, status=status.HTTP_400_BAD_REQUEST)
except Exception as exc:
message = self._extract_generic_error_message(exc)
self._add_error_to_response(response_body, message)
return Response(response_body, status=status.HTTP_400_BAD_REQUEST)
@staticmethod
def _extract_http_error_message(exc):
"""Extract error message from HTTPError, checking response JSON and text."""
message = str(exc)
if not hasattr(exc, 'response') or exc.response is None:
return message
try:
error_data = exc.response.json()
if 'errors' in error_data and error_data['errors']:
return ', '.join(error_data['errors'])
if 'error' in error_data:
return error_data['error']
except (ValueError, KeyError):
if exc.response.text:
return exc.response.text
return message
@staticmethod
def _extract_generic_error_message(exc):
"""Extract error message from exception, handling ConnectTimeoutError specially."""
message = str(exc) if str(exc) else exc.__class__.__name__
for arg in getattr(exc, 'args', []):
if isinstance(getattr(arg, 'reason', None), ConnectTimeoutError):
return str(arg.reason)
return message
@staticmethod
def _add_error_to_response(response_body, message):
"""Add error message to both 'detail' and 'details.error_message' fields."""
response_body['detail'] = message
if 'details' in response_body:
response_body['details']['error_message'] = message
class CredentialExternalTest(OIDCCredentialTestMixin, SubDetailAPIView):
class CredentialExternalTest(SubDetailAPIView):
"""
Test updates to the input values and metadata of an external credential
before saving them.
@@ -1773,8 +1615,6 @@ class CredentialExternalTest(OIDCCredentialTestMixin, SubDetailAPIView):
It does not support standard credential types such as Machine, SCM, and Cloud."""})
def post(self, request, *args, **kwargs):
obj = self.get_object()
if obj.credential_type.kind != 'external':
raise ParseError(_('Credential is not testable.'))
backend_kwargs = {}
for field_name, value in obj.inputs.items():
backend_kwargs[field_name] = obj.get_input(field_name)
@@ -1782,22 +1622,23 @@ class CredentialExternalTest(OIDCCredentialTestMixin, SubDetailAPIView):
if value != '$encrypted$':
backend_kwargs[field_name] = value
backend_kwargs.update(request.data.get('metadata', {}))
# Handle OIDC workload identity token generation if enabled
response_body = {}
if flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED') and self._has_workload_identity_token(obj.credential_type.inputs):
try:
oidc_response_body = self._handle_oidc_credential_test(backend_kwargs)
response_body.update(oidc_response_body)
except PermissionDenied:
raise
except Exception as exc:
error_message = str(exc.detail) if hasattr(exc, 'detail') else str(exc)
response_body['detail'] = error_message
response_body['details'] = {'error_message': error_message}
return Response(response_body, status=status.HTTP_400_BAD_REQUEST)
return self._call_backend_with_error_handling(obj.credential_type.plugin, backend_kwargs, response_body)
try:
with set_environ(**settings.AWX_TASK_ENV):
obj.credential_type.plugin.backend(**backend_kwargs)
return Response({}, status=status.HTTP_202_ACCEPTED)
except requests.exceptions.HTTPError:
message = """Test operation is not supported for credential type {}.
This endpoint only supports credentials that connect to
external secret management systems such as CyberArk, HashiCorp
Vault, or cloud-based secret managers.""".format(obj.credential_type.kind)
return Response({'detail': message}, status=status.HTTP_400_BAD_REQUEST)
except Exception as exc:
message = exc.__class__.__name__
exc_args = getattr(exc, 'args', [])
for a in exc_args:
if isinstance(getattr(a, 'reason', None), ConnectTimeoutError):
message = str(a.reason)
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
class CredentialInputSourceDetail(RetrieveUpdateDestroyAPIView):
@@ -1827,7 +1668,7 @@ class CredentialInputSourceSubList(SubListCreateAPIView):
parent_key = 'target_credential'
class CredentialTypeExternalTest(OIDCCredentialTestMixin, SubDetailAPIView):
class CredentialTypeExternalTest(SubDetailAPIView):
"""
Test a complete set of input values for an external credential before
saving it.
@@ -1842,26 +1683,21 @@ class CredentialTypeExternalTest(OIDCCredentialTestMixin, SubDetailAPIView):
@extend_schema_if_available(extensions={"x-ai-description": "Test a complete set of input values for an external credential"})
def post(self, request, *args, **kwargs):
obj = self.get_object()
if obj.kind != 'external':
raise ParseError(_('Credential type is not testable.'))
backend_kwargs = request.data.get('inputs', {})
backend_kwargs.update(request.data.get('metadata', {}))
# Handle OIDC workload identity token generation if enabled
response_body = {}
if flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED') and self._has_workload_identity_token(obj.inputs):
try:
oidc_response_body = self._handle_oidc_credential_test(backend_kwargs)
response_body.update(oidc_response_body)
except PermissionDenied:
raise
except Exception as exc:
error_message = str(exc.detail) if hasattr(exc, 'detail') else str(exc)
response_body['detail'] = error_message
response_body['details'] = {'error_message': error_message}
return Response(response_body, status=status.HTTP_400_BAD_REQUEST)
return self._call_backend_with_error_handling(obj.plugin, backend_kwargs, response_body)
try:
obj.plugin.backend(**backend_kwargs)
return Response({}, status=status.HTTP_202_ACCEPTED)
except requests.exceptions.HTTPError as exc:
message = 'HTTP {}'.format(exc.response.status_code)
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
except Exception as exc:
message = exc.__class__.__name__
args_exc = getattr(exc, 'args', [])
for a in args_exc:
if isinstance(getattr(a, 'reason', None), ConnectTimeoutError):
message = str(a.reason)
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
class HostRelatedSearchMixin(object):
@@ -1927,7 +1763,7 @@ class HostList(HostRelatedSearchMixin, ListCreateAPIView):
if filter_string:
filter_qs = SmartFilter.query_from_string(filter_string)
qs &= filter_qs
return qs.distinct().with_latest_summary_id()
return qs.distinct()
def list(self, *args, **kwargs):
try:
@@ -1942,9 +1778,6 @@ class HostDetail(RelatedJobsPreventDeleteMixin, RetrieveUpdateDestroyAPIView):
serializer_class = serializers.HostSerializer
resource_purpose = 'host detail'
def get_queryset(self):
return super().get_queryset().with_latest_summary_id()
@extend_schema_if_available(extensions={"x-ai-description": "Delete a host"})
def delete(self, request, *args, **kwargs):
if self.get_object().inventory.pending_deletion:
@@ -1978,9 +1811,6 @@ class InventoryHostsList(HostRelatedSearchMixin, SubListCreateAttachDetachAPIVie
filter_read_permission = False
resource_purpose = 'hosts of an inventory'
def get_queryset(self):
return super().get_queryset().with_latest_summary_id()
class HostGroupsList(SubListCreateAttachDetachAPIView):
'''the list of groups a host is directly a member of'''
@@ -2164,9 +1994,6 @@ class GroupHostsList(HostRelatedSearchMixin, SubListCreateAttachDetachAPIView):
relationship = 'hosts'
resource_purpose = 'hosts of a group'
def get_queryset(self):
return super().get_queryset().with_latest_summary_id()
def update_raw_data(self, data):
data.pop('inventory', None)
return super(GroupHostsList, self).update_raw_data(data)
@@ -2198,7 +2025,7 @@ class GroupAllHostsList(HostRelatedSearchMixin, SubListAPIView):
self.check_parent_access(parent)
qs = self.request.user.get_queryset(self.model).distinct() # need distinct for '&' operator
sublist_qs = parent.all_hosts.distinct()
return (qs & sublist_qs).with_latest_summary_id()
return qs & sublist_qs
class GroupInventorySourcesList(SubListAPIView):
@@ -2491,9 +2318,6 @@ class InventorySourceHostsList(HostRelatedSearchMixin, SubListDestroyAPIView):
check_sub_obj_permission = False
resource_purpose = 'hosts of an inventory source'
def get_queryset(self):
return super().get_queryset().with_latest_summary_id()
def perform_list_destroy(self, instance_list):
inv_source = self.get_parent_object()
with ignore_inventory_computed_fields():
@@ -3053,7 +2877,8 @@ class JobTemplateCallback(GenericAPIView):
host for the current request.
"""
# Find the list of remote host names/IPs to check.
remote_hosts = set(get_remote_hosts(self.request))
# Only consider the first entry from each header (for comma-separated values like X-Forwarded-For)
remote_hosts = get_first_remote_host_from_headers(self.request, settings.REMOTE_HOST_HEADERS)
# Add the reverse lookup of IP addresses.
for rh in list(remote_hosts):
try:
@@ -3851,7 +3676,7 @@ class SystemJobTemplateNotificationTemplatesSuccessList(SystemJobTemplateNotific
resource_purpose = 'notification templates triggered on system job success'
class JobList(UnifiedJobIncludeMixin, ListAPIView):
class JobList(ListAPIView):
model = models.Job
serializer_class = serializers.JobListSerializer
resource_purpose = 'jobs'
@@ -4568,7 +4393,7 @@ class UnifiedJobTemplateList(ListAPIView):
resource_purpose = 'unified job templates'
class UnifiedJobList(UnifiedJobIncludeMixin, ListAPIView):
class UnifiedJobList(ListAPIView):
model = models.UnifiedJob
serializer_class = serializers.UnifiedJobListSerializer
search_fields = ('description', 'name', 'job__playbook')
@@ -4871,12 +4696,19 @@ class RoleUsersList(SubListAttachDetachAPIView):
if not sub_id:
return super(RoleUsersList, self).post(request)
if not request.data.get('disassociate'):
user = get_object_or_400(models.User, pk=sub_id)
role = self.get_parent_object()
content_object = role.content_object
if hasattr(content_object, 'validate_role_assignment'):
content_object.validate_role_assignment(user, role_definition=None, requesting_user=request.user)
user = get_object_or_400(models.User, pk=sub_id)
role = self.get_parent_object()
content_types = ContentType.objects.get_for_models(models.Organization, models.Team, models.Credential) # dict of {model: content_type}
credential_content_type = content_types[models.Credential]
if role.content_type == credential_content_type:
if 'disassociate' not in request.data and role.content_object.organization and user not in role.content_object.organization.member_role:
data = dict(msg=_("You cannot grant credential access to a user not in the credentials' organization"))
return Response(data, status=status.HTTP_400_BAD_REQUEST)
if not role.content_object.organization and not request.user.is_superuser:
data = dict(msg=_("You cannot grant private credential access to another user"))
return Response(data, status=status.HTTP_400_BAD_REQUEST)
return super(RoleUsersList, self).post(request, *args, **kwargs)
@@ -4909,6 +4741,24 @@ class RoleTeamsList(SubListAttachDetachAPIView):
data = dict(msg=_("You cannot assign an Organization participation role as a child role for a Team."))
return Response(data, status=status.HTTP_400_BAD_REQUEST)
credential_content_type = ContentType.objects.get_for_model(models.Credential)
if role.content_type == credential_content_type:
# Private credentials (no organization) are never allowed for teams
if not role.content_object.organization:
data = dict(
msg=_("You cannot grant access to a credential that is not assigned to an organization (private credentials cannot be assigned to teams)")
)
return Response(data, status=status.HTTP_400_BAD_REQUEST)
# Cross-organization credentials are only allowed for superusers
elif role.content_object.organization.id != team.organization.id:
if not request.user.is_superuser:
data = dict(
msg=_(
"You cannot grant a team access to a credential in a different organization. Only superusers can grant cross-organization credential access to teams"
)
)
return Response(data, status=status.HTTP_400_BAD_REQUEST)
action = 'attach'
if request.data.get('disassociate', None):
action = 'unattach'
@@ -4917,11 +4767,6 @@ class RoleTeamsList(SubListAttachDetachAPIView):
data = dict(msg=_("You cannot grant system-level permissions to a team."))
return Response(data, status=status.HTTP_400_BAD_REQUEST)
if action == 'attach':
content_object = role.content_object
if hasattr(content_object, 'validate_role_assignment'):
content_object.validate_role_assignment(team, role_definition=None, requesting_user=request.user)
if not request.user.can_access(self.parent_model, action, role, team, self.relationship, request.data, skip_sub_obj_read_check=False):
raise PermissionDenied()
if request.data.get('disassociate', None):

View File

@@ -9,7 +9,7 @@ from django.utils import translation
from awx.api.generics import APIView, Response
from awx.api.permissions import AnalyticsPermission
from awx.api.versioning import reverse
from awx.main.utils import get_awx_version, set_environ
from awx.main.utils import get_awx_version
from awx.main.utils.analytics_proxy import OIDCClient
from rest_framework import status
@@ -49,6 +49,7 @@ class GetNotAllowedMixin(object):
class AnalyticsRootView(APIView):
permission_classes = (AnalyticsPermission,)
name = _('Automation Analytics')
swagger_topic = 'Automation Analytics'
resource_purpose = 'automation analytics endpoints'
@extend_schema_if_available(extensions={"x-ai-description": "A list of additional API endpoints related to analytics"})
@@ -210,32 +211,31 @@ class AnalyticsGenericView(APIView):
return self._error_response(ERROR_UNSUPPORTED_METHOD, method, remote=False, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
url = self._get_analytics_url(request.path)
using_subscriptions_credentials = False
with set_environ(**settings.AWX_TASK_ENV):
try:
rh_user = getattr(settings, 'REDHAT_USERNAME', None)
rh_password = getattr(settings, 'REDHAT_PASSWORD', None)
if not (rh_user and rh_password):
rh_user = self._get_setting('SUBSCRIPTIONS_CLIENT_ID', None, ERROR_MISSING_USER)
rh_password = self._get_setting('SUBSCRIPTIONS_CLIENT_SECRET', None, ERROR_MISSING_PASSWORD)
using_subscriptions_credentials = True
try:
rh_user = getattr(settings, 'REDHAT_USERNAME', None)
rh_password = getattr(settings, 'REDHAT_PASSWORD', None)
if not (rh_user and rh_password):
rh_user = self._get_setting('SUBSCRIPTIONS_CLIENT_ID', None, ERROR_MISSING_USER)
rh_password = self._get_setting('SUBSCRIPTIONS_CLIENT_SECRET', None, ERROR_MISSING_PASSWORD)
using_subscriptions_credentials = True
client = OIDCClient(rh_user, rh_password)
response = client.make_request(
method,
url,
headers=headers,
verify=settings.INSIGHTS_CERT_PATH,
params=getattr(request, 'query_params', {}),
json=getattr(request, 'data', {}),
timeout=(31, 31),
)
except requests.RequestException:
# subscriptions credentials are not valid for basic auth, so just return 401
if using_subscriptions_credentials:
response = Response(status=status.HTTP_401_UNAUTHORIZED)
else:
logger.error("Automation Analytics API request failed, trying base auth method")
response = self._base_auth_request(request, method, url, rh_user, rh_password, headers)
client = OIDCClient(rh_user, rh_password)
response = client.make_request(
method,
url,
headers=headers,
verify=settings.INSIGHTS_CERT_PATH,
params=getattr(request, 'query_params', {}),
json=getattr(request, 'data', {}),
timeout=(31, 31),
)
except requests.RequestException:
# subscriptions credentials are not valid for basic auth, so just return 401
if using_subscriptions_credentials:
response = Response(status=status.HTTP_401_UNAUTHORIZED)
else:
logger.error("Automation Analytics API request failed, trying base auth method")
response = self._base_auth_request(request, method, url, rh_user, rh_password, headers)
#
# Missing or wrong user/pass
#
@@ -306,6 +306,7 @@ class AnalyticsAuthorizedView(AnalyticsGenericListView):
class AnalyticsReportsList(GetNotAllowedMixin, AnalyticsGenericListView):
name = _("Reports")
swagger_topic = "Automation Analytics"
resource_purpose = 'automation analytics reports'

View File

@@ -212,9 +212,3 @@ class NoTruncateMixin(object):
if self.request.query_params.get('no_truncate'):
context.update(no_truncate=True)
return context
class UnifiedJobIncludeMixin(object):
# Reserve the name 'include' so we can use it as a query param. Otherwise, the rest-filters backend
# would treat it as a model field lookup.
rest_filters_reserved_names = ('include',)

View File

@@ -344,22 +344,13 @@ class ApiV2ConfigView(APIView):
become_methods=PRIVILEGE_ESCALATION_METHODS,
)
# Check superuser/auditor first
if request.user.is_superuser or request.user.is_system_auditor:
has_org_access = True
else:
# Single query checking all three organization role types at once
has_org_access = (
(
Organization.access_qs(request.user, 'change')
| Organization.access_qs(request.user, 'audit')
| Organization.access_qs(request.user, 'add_project')
)
.distinct()
.exists()
)
if has_org_access:
if (
request.user.is_superuser
or request.user.is_system_auditor
or Organization.accessible_objects(request.user, 'admin_role').exists()
or Organization.accessible_objects(request.user, 'auditor_role').exists()
or Organization.accessible_objects(request.user, 'project_admin_role').exists()
):
data.update(
dict(
project_base_dir=settings.PROJECTS_ROOT,
@@ -367,10 +358,8 @@ class ApiV2ConfigView(APIView):
custom_virtualenvs=get_custom_venv_choices(),
)
)
else:
# Only check JobTemplate access if org check failed
if JobTemplate.accessible_objects(request.user, 'admin_role').exists():
data['custom_virtualenvs'] = get_custom_venv_choices()
elif JobTemplate.accessible_objects(request.user, 'admin_role').exists():
data['custom_virtualenvs'] = get_custom_venv_choices()
return Response(data)

View File

@@ -17,7 +17,7 @@ from awx.api import serializers
from awx.api.generics import APIView, GenericAPIView
from awx.api.permissions import WebhookKeyPermission
from awx.main.models import Job, JobTemplate, WorkflowJob, WorkflowJobTemplate
from awx.main.utils.common import get_job_variable_prefixes
from awx.main.constants import JOB_VARIABLE_PREFIXES
logger = logging.getLogger('awx.api.views.webhooks')
@@ -166,7 +166,7 @@ class WebhookReceiverBase(APIView):
'extra_vars': {},
}
for name in get_job_variable_prefixes():
for name in JOB_VARIABLE_PREFIXES:
kwargs['extra_vars']['{}_webhook_event_type'.format(name)] = event_type
kwargs['extra_vars']['{}_webhook_event_guid'.format(name)] = event_guid
kwargs['extra_vars']['{}_webhook_event_ref'.format(name)] = event_ref

View File

@@ -897,6 +897,8 @@ class HostAccess(BaseAccess):
'created_by',
'modified_by',
'inventory',
'last_job__job_template',
'last_job_host_summary__job',
)
prefetch_related = ('groups', 'inventory_sources')

View File

@@ -8,7 +8,6 @@ import pathlib
import shutil
import tarfile
import tempfile
from urllib.parse import urlparse, urlunparse
from django.conf import settings
from django.core.serializers.json import DjangoJSONEncoder
@@ -24,8 +23,6 @@ from awx.main.models import Job
from awx.main.access import access_registry
from awx.main.utils import get_awx_http_client_headers, set_environ, datetime_hook
from awx.main.utils.analytics_proxy import OIDCClient
from awx.main.utils.candlepin import get_or_generate_candlepin_certificate
from awx.main.utils.candlepin.client import _temp_cert_files
__all__ = ['register', 'gather', 'ship']
@@ -44,76 +41,6 @@ def _valid_license():
return True
def _get_cert_upload_url(url):
"""
Convert analytics URL to use 'cert.' subdomain for mTLS uploads.
Some analytics services use different hostnames for different auth methods:
- cert.example.com - for mTLS (certificate-based) uploads
- example.com - for OIDC (token-based) uploads
Args:
url: Original analytics URL
Returns:
URL with 'cert.' prepended to hostname if not already present
"""
try:
parsed = urlparse(url)
hostname = parsed.hostname
# Only modify if hostname doesn't already start with 'cert.'
if hostname and not hostname.startswith('cert.'):
new_hostname = f'cert.{hostname}'
# Reconstruct URL with new hostname
netloc = new_hostname
if parsed.port:
netloc = f'{new_hostname}:{parsed.port}'
new_parsed = parsed._replace(netloc=netloc)
return urlunparse(new_parsed)
return url
except Exception as e:
logger.warning(f'Could not modify URL for cert upload: {e}, using original URL')
return url
def _get_analytics_credentials():
"""
Get Red Hat Insights credentials from settings.
Attempts to retrieve credentials in the following priority order:
1. REDHAT_USERNAME / REDHAT_PASSWORD
2. SUBSCRIPTIONS_USERNAME / SUBSCRIPTIONS_PASSWORD
3. SUBSCRIPTIONS_CLIENT_ID / SUBSCRIPTIONS_CLIENT_SECRET
Returns:
tuple: (username, password) if credentials are found, (None, None) otherwise
"""
rh_id = getattr(settings, 'REDHAT_USERNAME', None)
rh_secret = getattr(settings, 'REDHAT_PASSWORD', None)
if rh_id and rh_secret:
return rh_id, rh_secret
# Try SUBSCRIPTIONS_USERNAME / SUBSCRIPTIONS_PASSWORD
rh_id = getattr(settings, 'SUBSCRIPTIONS_USERNAME', None)
rh_secret = getattr(settings, 'SUBSCRIPTIONS_PASSWORD', None)
if rh_id and rh_secret:
return rh_id, rh_secret
# Try SUBSCRIPTIONS_CLIENT_ID / SUBSCRIPTIONS_CLIENT_SECRET
rh_id = getattr(settings, 'SUBSCRIPTIONS_CLIENT_ID', None)
rh_secret = getattr(settings, 'SUBSCRIPTIONS_CLIENT_SECRET', None)
if rh_id and rh_secret:
return rh_id, rh_secret
return None, None
def all_collectors():
from awx.main.analytics import collectors
@@ -257,8 +184,10 @@ def gather(dest=None, module=None, subset=None, since=None, until=None, collecti
logger.log(log_level, "Automation Analytics not enabled. Use --dry-run to gather locally without sending.")
return None
rh_id, rh_secret = _get_analytics_credentials()
if not (settings.AUTOMATION_ANALYTICS_URL and rh_id and rh_secret):
if not (
settings.AUTOMATION_ANALYTICS_URL
and ((settings.REDHAT_USERNAME and settings.REDHAT_PASSWORD) or (settings.SUBSCRIPTIONS_CLIENT_ID and settings.SUBSCRIPTIONS_CLIENT_SECRET))
):
logger.log(log_level, "Not gathering analytics, configuration is invalid. Use --dry-run to gather locally without sending.")
return None
@@ -439,14 +368,19 @@ def ship(path):
logger.error('AUTOMATION_ANALYTICS_URL is not set')
return False
rh_id, rh_secret = _get_analytics_credentials()
rh_id = getattr(settings, 'REDHAT_USERNAME', None)
rh_secret = getattr(settings, 'REDHAT_PASSWORD', None)
if not (rh_id and rh_secret):
rh_id = getattr(settings, 'SUBSCRIPTIONS_CLIENT_ID', None)
rh_secret = getattr(settings, 'SUBSCRIPTIONS_CLIENT_SECRET', None)
if not rh_id:
logger.error('No valid username found. Tried: REDHAT_USERNAME, SUBSCRIPTIONS_USERNAME, SUBSCRIPTIONS_CLIENT_ID')
logger.error('Neither REDHAT_USERNAME nor SUBSCRIPTIONS_CLIENT_ID are set')
return False
if not rh_secret:
logger.error('No valid password found. Tried: REDHAT_PASSWORD, SUBSCRIPTIONS_PASSWORD, SUBSCRIPTIONS_CLIENT_SECRET')
logger.error('Neither REDHAT_PASSWORD nor SUBSCRIPTIONS_CLIENT_SECRET are set')
return False
with open(path, 'rb') as f:
@@ -454,40 +388,17 @@ def ship(path):
s = requests.Session()
s.headers = get_awx_http_client_headers()
s.headers.pop('Content-Type')
with set_environ(**settings.AWX_TASK_ENV):
# Try Certificate-based mTLS authentication (zero-touch)
cert_pem, key_pem = get_or_generate_candlepin_certificate()
if cert_pem and key_pem:
# Use cert. subdomain for mTLS uploads
cert_url = _get_cert_upload_url(url)
logger.debug("Attempting certificate-based authentication for analytics upload")
try:
with _temp_cert_files(cert_pem, key_pem) as (cert_path, key_path):
response = s.post(
cert_url, files=files, cert=(cert_path, key_path), verify=settings.INSIGHTS_CERT_PATH, headers=s.headers, timeout=(31, 31)
)
if response.status_code < 300:
return True
else:
logger.warning(
f'Certificate-based authentication failed with status {response.status_code}, {response.text}. Falling back to OIDC auth'
)
except Exception as e:
logger.warning(f"Certificate-based authentication failed: {e}, falling back to OIDC auth")
# Try OIDC authentication
logger.debug("Attempting OIDC authentication for analytics upload")
f.seek(0) # requests POST may read from the handler, so seek to beginning of file for the next POST attempt
try:
client = OIDCClient(rh_id, rh_secret)
response = client.make_request("POST", url, headers=s.headers, files=files, verify=settings.INSIGHTS_CERT_PATH, timeout=(31, 31))
except requests.RequestException:
logger.error("Automation Analytics API request failed, trying base auth method")
response = s.post(url, files=files, verify=settings.INSIGHTS_CERT_PATH, auth=(rh_id, rh_secret), headers=s.headers, timeout=(31, 31))
if response.status_code < 300:
return True
else:
logger.error(f'OIDC authentication failed with status {response.status_code}, {response.text}')
return False
except requests.RequestException as e:
logger.error(f"OIDC authentication failed: {e}")
return False
# Accept 2XX status_codes
if response.status_code >= 300:
logger.error('Upload failed with status {}, {}'.format(response.status_code, response.text))
return False
return True

View File

@@ -213,40 +213,6 @@ register(
category_slug='system',
)
register(
'AWX_ANALYTICS_CANDLEPIN_CA',
field_class=fields.CharField,
default='/etc/rhsm/ca/redhat-uep.pem',
allow_blank=True,
label=_('Candlepin CA Certificate Path'),
help_text=_('Path to the CA certificate file for verifying TLS connections to Candlepin. Leave blank to use system certificates.'),
category=_('System'),
category_slug='system',
)
register(
'AWX_ANALYTICS_CANDLEPIN_RENEWAL_THRESHOLD_DAYS',
field_class=fields.IntegerField,
default=90,
min_value=1,
label=_('Candlepin Certificate Renewal Threshold'),
help_text=_('Number of days before certificate expiry to trigger automatic renewal of Candlepin identity certificates.'),
category=_('System'),
category_slug='system',
unit=_('days'),
)
register(
'AWX_ANALYTICS_CANDLEPIN_PROXY_URL',
field_class=fields.CharField,
default='',
allow_blank=True,
label=_('Candlepin Proxy URL'),
help_text=_('HTTP/HTTPS proxy URL for Candlepin API requests (e.g., http://proxy.example.com:8080). Leave blank for no proxy.'),
category=_('System'),
category_slug='system',
)
register(
'INSTALL_UUID',
field_class=fields.CharField,
@@ -325,22 +291,6 @@ register(
category_slug='jobs',
)
register(
'INCLUDE_DEPRECATED_AWX_VAR_PREFIX',
field_class=fields.BooleanField,
default=True,
label=_('Include Deprecated AWX Variable Prefix'),
help_text=_(
'When enabled (default), auto-generated job variables are emitted '
'with both the tower_ prefix and the deprecated awx_ prefix for '
'backward compatibility. Disable to emit only tower_ prefixed '
'variables and eliminate duplicates. The awx_ prefix is deprecated '
'and this setting will default to False in a future release.'
),
category=_('Jobs'),
category_slug='jobs',
)
register(
'AWX_ISOLATION_BASE_PATH',
field_class=fields.CharField,
@@ -874,58 +824,6 @@ register(
unit=_('seconds'),
)
register(
'CANDLEPIN_CONSUMER_UUID',
field_class=fields.CharField,
default='',
allow_blank=True,
encrypted=False,
label=_('Candlepin Consumer UUID'),
help_text=_('UUID of the registered Candlepin consumer for this AAP instance.'),
category=_('System'),
category_slug='system',
hidden=True,
)
register(
'CANDLEPIN_CERT_PEM',
field_class=fields.CharField,
default='',
allow_blank=True,
encrypted=True,
label=_('Candlepin Identity Certificate'),
help_text=_('PEM-encoded Candlepin identity certificate for mTLS authentication.'),
category=_('System'),
category_slug='system',
hidden=True,
)
register(
'CANDLEPIN_KEY_PEM',
field_class=fields.CharField,
default='',
allow_blank=True,
encrypted=True,
label=_('Candlepin Identity Key'),
help_text=_('PEM-encoded private key for Candlepin identity certificate.'),
category=_('System'),
category_slug='system',
hidden=True,
)
register(
'CANDLEPIN_SERIAL_NUMBER',
field_class=fields.CharField,
default='',
allow_blank=True,
encrypted=False,
label=_('Candlepin Certificate Serial Number'),
help_text=_('Serial number of the Candlepin identity certificate for tracking.'),
category=_('System'),
category_slug='system',
hidden=True,
)
register(
'IS_K8S',
field_class=fields.BooleanField,

View File

@@ -11,7 +11,6 @@ __all__ = [
'CAN_CANCEL',
'ACTIVE_STATES',
'STANDARD_INVENTORY_UPDATE_ENV',
'OIDC_CREDENTIAL_TYPE_NAMESPACES',
]
PRIVILEGE_ESCALATION_METHODS = [
@@ -100,6 +99,10 @@ MAX_ISOLATED_PATH_COLON_DELIMITER = 2
SURVEY_TYPE_MAPPING = {'text': str, 'textarea': str, 'password': str, 'multiplechoice': str, 'multiselect': str, 'integer': int, 'float': (float, int)}
JOB_VARIABLE_PREFIXES = [
'awx',
'tower',
]
# Note, the \u001b[... are ansi color codes. We don't currenly import any of the python modules which define the codes.
# Importing a library just for this message seemed like overkill
@@ -137,6 +140,3 @@ org_role_to_permission = {
'execution_environment_admin_role': 'add_executionenvironment',
'auditor_role': 'view_project', # TODO: also doesnt really work
}
# OIDC credential type namespaces for feature flag filtering
OIDC_CREDENTIAL_TYPE_NAMESPACES = ['hashivault-kv-oidc', 'hashivault-ssh-oidc']

View File

@@ -27,11 +27,6 @@ def get_dispatcherd_config(for_service: bool = False, mock_publish: bool = False
"pool_kwargs": {
"min_workers": settings.JOB_EVENT_WORKERS,
"max_workers": max_workers,
# This must be less than max_workers to make sense, which is usually 4
# With reserve of 1, after a burst of tasks, load needs to down to 4-1=3
# before we return to min_workers
"scaledown_reserve": 1,
"worker_max_lifetime_seconds": settings.WORKER_MAX_LIFETIME_SECONDS,
},
"main_kwargs": {"node_id": settings.CLUSTER_HOST_ID},
"process_manager_cls": "ForkServerManager",

View File

@@ -77,13 +77,13 @@ class CallbackBrokerWorker:
MAX_RETRIES = 2
INDIVIDUAL_EVENT_RETRIES = 3
last_stats = time.time()
last_flush = time.time()
total = 0
last_event = ''
prof = None
def __init__(self):
self.last_stats = time.time()
self.last_flush = time.time()
self.buff = {}
self.redis = get_redis_client()
self.subsystem_metrics = s_metrics.CallbackReceiverMetrics(auto_pipe_execute=False)

View File

@@ -428,9 +428,6 @@ class CredentialInputField(JSONSchemaField):
# determine the defined fields for the associated credential type
properties = {}
for field in model_instance.credential_type.inputs.get('fields', []):
# Prevent users from providing values for internally resolved fields
if 'internal' in field:
continue
field = field.copy()
properties[field['id']] = field
if field.get('choices', []):
@@ -569,7 +566,6 @@ class CredentialTypeInputField(JSONSchemaField):
},
'label': {'type': 'string'},
'help_text': {'type': 'string'},
'internal': {'type': 'boolean'},
'multiline': {'type': 'boolean'},
'secret': {'type': 'boolean'},
'ask_at_runtime': {'type': 'boolean'},

View File

@@ -1,330 +0,0 @@
import sys
from argparse import RawDescriptionHelpFormatter
from django.core.management.base import BaseCommand
from awx.main.utils.candlepin.client import CandlepinClient
from awx.main.utils.candlepin.lifecycle import (
get_candlepin_ca,
get_candlepin_url,
get_proxy_url,
get_renewal_days,
needs_renewal,
parse_cert,
)
from awx.main.utils.candlepin import (
_fetch_candlepin_cert_from_db,
_save_candlepin_cert_to_db,
_save_candlepin_registration_to_db,
resolve_registration_credentials,
)
class Command(BaseCommand):
"""
Manage Candlepin consumer registration and certificate lifecycle.
Subcommands:
register Register this AAP instance as a Candlepin consumer and obtain an
identity certificate for mTLS analytics uploads.
renew Perform a manual check-in and, if needed, renew the stored identity
certificate.
"""
help = 'Manage Candlepin consumer registration and certificate lifecycle'
def create_parser(self, prog_name, subcommand, **kwargs):
return super().create_parser(
prog_name,
subcommand,
formatter_class=RawDescriptionHelpFormatter,
epilog='\n'.join(
[
'SUBCOMMANDS',
'',
' register Register this instance as a Candlepin consumer.',
' Credentials are read from AWX database by default',
' (REDHAT_USERNAME, REDHAT_PASSWORD). The organization is',
' discovered automatically from the Candlepin account.',
' Pass --username / --password-stdin / --org to override.',
' Example: echo "password" | awx-manage candlepin_cert register --username user --password-stdin',
'',
' renew Perform a manual check-in and proactive cert renewal.',
' Reads the stored cert/key/UUID from database.',
' Use --force to renew even if the cert is not near expiry.',
'',
'CONFIGURATION',
'',
' Settings can be configured via Django settings (awx/settings/defaults.py):',
'',
' AWX_ANALYTICS_CANDLEPIN_URL Candlepin base URL',
' (default: https://subscription.example.com/candlepin)',
' AWX_ANALYTICS_CANDLEPIN_CA Path to Candlepin CA cert for TLS verification',
' AWX_ANALYTICS_CANDLEPIN_RENEWAL_THRESHOLD_DAYS Days before expiry to trigger renewal (default: 90)',
' AWX_ANALYTICS_CANDLEPIN_PROXY_URL HTTP/HTTPS proxy for Candlepin API calls',
]
),
**kwargs,
)
def add_arguments(self, parser):
subparsers = parser.add_subparsers(dest='subcommand', metavar='subcommand')
subparsers.required = True
# --- register ---
reg = subparsers.add_parser(
'register',
help='Register this instance as a Candlepin consumer',
formatter_class=RawDescriptionHelpFormatter,
)
reg.add_argument('--username', help='Red Hat subscription username (overrides REDHAT_USERNAME from database)')
reg.add_argument(
'--password-stdin', dest='password_stdin', action='store_true', help='Read password from stdin (overrides REDHAT_PASSWORD from database)'
)
reg.add_argument('--org', help='Candlepin owner/org key (overrides auto-discovered organization)')
reg.add_argument('--candlepin-url', dest='candlepin_url', help='Candlepin base URL (overrides AWX_ANALYTICS_CANDLEPIN_URL setting)')
reg.add_argument(
'--candlepin-ca', dest='candlepin_ca', help='Path to Candlepin CA cert for TLS verification (overrides AWX_ANALYTICS_CANDLEPIN_CA setting)'
)
reg.add_argument('--proxy', help='HTTP/HTTPS proxy URL (overrides AWX_ANALYTICS_CANDLEPIN_PROXY_URL setting)')
reg.add_argument('--no-verify-tls', dest='no_verify_tls', action='store_true', help='Disable TLS certificate verification for Candlepin API calls')
reg.add_argument('--force', action='store_true', help='Re-register even if a certificate already exists in database')
reg.add_argument('--dry-run', dest='dry_run', action='store_true', help='Perform registration but do not save the result to database')
# --- renew ---
ren = subparsers.add_parser(
'renew',
help='Check in and renew the Candlepin identity certificate',
formatter_class=RawDescriptionHelpFormatter,
)
ren.add_argument('--candlepin-url', dest='candlepin_url', help='Candlepin base URL (overrides AWX_ANALYTICS_CANDLEPIN_URL setting)')
ren.add_argument(
'--candlepin-ca', dest='candlepin_ca', help='Path to Candlepin CA cert for TLS verification (overrides AWX_ANALYTICS_CANDLEPIN_CA setting)'
)
ren.add_argument('--proxy', help='HTTP/HTTPS proxy URL (overrides AWX_ANALYTICS_CANDLEPIN_PROXY_URL setting)')
ren.add_argument('--no-verify-tls', dest='no_verify_tls', action='store_true', help='Disable TLS certificate verification for Candlepin API calls')
ren.add_argument('--force', action='store_true', help='Renew the certificate even if it is not near expiry')
ren.add_argument('--dry-run', dest='dry_run', action='store_true', help='Perform check-in and renewal but do not save the result to database')
def handle(self, *args, **options):
subcommand = options['subcommand']
if subcommand == 'register':
ok = self._handle_register(options)
elif subcommand == 'renew':
ok = self._handle_renew(options)
else:
self.stderr.write(f'Unknown subcommand: {subcommand}')
sys.exit(1)
if not ok:
sys.exit(1)
# ------------------------------------------------------------------
# register
# ------------------------------------------------------------------
def _resolve_and_validate_credentials(self, options):
"""Merge CLI options with DB values and validate all required fields are present.
Returns ``(username, password, org, db_install_uuid)`` on success, or ``None``
if any required field is missing (errors are written to ``self.stderr``).
"""
username_override = options.get('username')
org_override = options.get('org')
verify_tls = not options.get('no_verify_tls', False)
# Read password from stdin if --password-stdin is set
if options.get('password_stdin'):
password_override = sys.stdin.read().strip()
if not password_override:
self.stderr.write('--password-stdin specified but no password provided on stdin')
return None
else:
password_override = None
# Use shared resolution and validation function
username, password, org, install_uuid, errors = resolve_registration_credentials(
username_override=username_override, password_override=password_override, org_override=org_override, verify_tls=verify_tls
)
if errors:
for error in errors:
self.stderr.write(f'Missing required value: {error}')
return None
return username, password, org, install_uuid
def _handle_register(self, options):
dry_run = options['dry_run']
force = options['force']
# Check whether a cert is already stored unless --force.
existing_cert, existing_key, _ = _fetch_candlepin_cert_from_db()
if existing_cert and existing_key and not force:
self.stdout.write('A Candlepin identity certificate is already stored in database. Use --force to re-register and replace it.')
return True
# Resolve credentials: CLI flags take precedence over database.
resolved = self._resolve_and_validate_credentials(options)
if resolved is None:
return False
username, password, org, db_install_uuid = resolved
candlepin_url = options.get('candlepin_url') or get_candlepin_url()
candlepin_ca = options.get('candlepin_ca') or get_candlepin_ca()
proxy = options.get('proxy') or get_proxy_url()
verify_tls = not options.get('no_verify_tls', False)
# If dry-run, display what would happen and exit early before any Candlepin operations
if dry_run:
self.stdout.write('[dry-run] Would register with Candlepin:')
self.stdout.write(f' URL : {candlepin_url}')
self.stdout.write(f' Organization : {org}')
self.stdout.write(f' Username : {username}')
self.stdout.write(f' Install UUID : {db_install_uuid}')
if candlepin_ca:
self.stdout.write(f' CA cert : {candlepin_ca}')
if proxy:
self.stdout.write(f' Proxy : {proxy}')
self.stdout.write(f' Verify TLS : {verify_tls}')
self.stdout.write('[dry-run] No Candlepin operations performed.')
return True
client = CandlepinClient(base_url=candlepin_url, candlepin_ca=candlepin_ca, proxy=proxy, verify_tls=verify_tls)
self.stdout.write(f'Registering with Candlepin at {candlepin_url} (org={org}) ...')
try:
cert_pem, key_pem, consumer_uuid = client.register_consumer(username, password, org, install_uuid=db_install_uuid)
except Exception as e:
self.stderr.write(f'Registration failed: {e}')
return False
self.stdout.write('Registered successfully.')
self.stdout.write(f' Consumer UUID : {consumer_uuid}')
# Save to database
if _save_candlepin_registration_to_db(cert_pem, key_pem, consumer_uuid):
self.stdout.write('Certificate, key, and consumer UUID saved to database.')
else:
self.stderr.write('Failed to save registration to database.')
return False
# Best-effort certificate metadata display
try:
info = parse_cert(cert_pem)
self.stdout.write(f' Cert serial : {info["serial"]}')
self.stdout.write(f' Cert CN : {info["cn"]}')
self.stdout.write(f' Valid until : {info["not_after"]} ({info["days_remaining"]} days remaining)')
except ValueError as e:
self.stdout.write(f'Certificate metadata unavailable: {e}')
return True
# ------------------------------------------------------------------
# renew
# ------------------------------------------------------------------
def _handle_renew(self, options):
dry_run = options['dry_run']
force = options['force']
cert_pem, key_pem, consumer_uuid = _fetch_candlepin_cert_from_db()
if not cert_pem or not key_pem:
self.stderr.write('No Candlepin identity certificate found in database. Run the register subcommand first.')
return False
if not consumer_uuid:
self.stderr.write('CANDLEPIN_CONSUMER_UUID is not set. Run the register subcommand first.')
return False
try:
info = parse_cert(cert_pem)
self.stdout.write('Current certificate:')
self.stdout.write(f' Serial : {info["serial"]}')
self.stdout.write(f' CN : {info["cn"]}')
self.stdout.write(f' Valid until : {info["not_after"]} ({info["days_remaining"]} days remaining)')
except ValueError as e:
self.stdout.write('Current certificate:')
self.stdout.write(f' Certificate metadata unavailable: {e}')
info = None
candlepin_url = options.get('candlepin_url') or get_candlepin_url()
candlepin_ca = options.get('candlepin_ca') or get_candlepin_ca()
proxy = options.get('proxy') or get_proxy_url()
verify_tls = not options.get('no_verify_tls', False)
renewal_days = get_renewal_days()
# Check if renewal is needed (without force, just check cert expiry locally)
renewal_needed = force or needs_renewal(cert_pem, renewal_days)
# If dry-run, display what would happen and exit early before any Candlepin operations
if dry_run:
self.stdout.write('[dry-run] Would perform the following operations:')
self.stdout.write(f' URL : {candlepin_url}')
self.stdout.write(f' Consumer UUID : {consumer_uuid}')
if candlepin_ca:
self.stdout.write(f' CA cert : {candlepin_ca}')
if proxy:
self.stdout.write(f' Proxy : {proxy}')
self.stdout.write(f' Verify TLS : {verify_tls}')
self.stdout.write(' 1. Check in with Candlepin')
if renewal_needed:
reason = 'forced via --force' if force else f'expiry within {renewal_days} days'
self.stdout.write(f' 2. Renew certificate ({reason})')
else:
if info:
self.stdout.write(f' 2. No renewal needed ({info["days_remaining"]} days remaining, threshold: {renewal_days} days)')
else:
self.stdout.write(f' 2. No renewal needed (threshold: {renewal_days} days)')
self.stdout.write('[dry-run] No Candlepin operations performed.')
return True
client = CandlepinClient(base_url=candlepin_url, candlepin_ca=candlepin_ca, proxy=proxy, verify_tls=verify_tls)
self.stdout.write(f'Checking in with Candlepin at {candlepin_url} (consumer={consumer_uuid}) ...')
checkin_success = client.checkin(consumer_uuid, cert_pem, key_pem)
if not checkin_success:
self.stderr.write('Check-in with Candlepin failed. Unable to verify certificate status.')
self.stderr.write('Certificate renewal may still be needed. Use --force to renew anyway, or check logs for details.')
return False
self.stdout.write('Check-in successful.')
if not renewal_needed:
if info:
self.stdout.write(f'Certificate has {info["days_remaining"]} days remaining (renewal threshold: {renewal_days} days). No renewal needed.')
else:
self.stdout.write(f'Certificate renewal threshold is {renewal_days} days. No renewal needed.')
return True
reason = 'forced via --force' if force else f'expiry within {renewal_days} days'
self.stdout.write(f'Renewing certificate ({reason}) ...')
try:
new_cert_pem, new_key_pem = client.regenerate_cert(consumer_uuid, cert_pem, key_pem)
except Exception as e:
self.stderr.write(f'Certificate renewal failed: {e}')
return False
self.stdout.write('Certificate renewed successfully.')
# Save to database
if _save_candlepin_cert_to_db(new_cert_pem, new_key_pem):
self.stdout.write('Renewed certificate and key saved to database.')
else:
self.stderr.write('Failed to save renewed certificate to database.')
return False
# Best-effort certificate metadata display
try:
new_info = parse_cert(new_cert_pem)
if info:
self.stdout.write(f' Old serial : {info["serial"]}')
self.stdout.write(f' New serial : {new_info["serial"]}')
self.stdout.write(f' Valid until : {new_info["not_after"]} ({new_info["days_remaining"]} days remaining)')
except ValueError as e:
self.stdout.write(f'Certificate metadata unavailable: {e}')
return True

View File

@@ -409,12 +409,10 @@ class Command(BaseCommand):
del_child_group_pks = list(set(db_children_name_pk_map.values()))
for offset in range(0, len(del_child_group_pks), self._batch_size):
child_group_pks = del_child_group_pks[offset : (offset + self._batch_size)]
children_to_remove = list(db_children.filter(pk__in=child_group_pks))
if children_to_remove:
group_group_count += len(children_to_remove)
db_group.children.remove(*children_to_remove)
for db_child in children_to_remove:
logger.debug('Group "%s" removed from group "%s"', db_child.name, db_group.name)
for db_child in db_children.filter(pk__in=child_group_pks):
group_group_count += 1
db_group.children.remove(db_child)
logger.debug('Group "%s" removed from group "%s"', db_child.name, db_group.name)
# FIXME: Inventory source group relationships
# Delete group/host relationships not present in imported data.
db_hosts = db_group.hosts
@@ -443,12 +441,12 @@ class Command(BaseCommand):
del_host_pks = list(del_host_pks)
for offset in range(0, len(del_host_pks), self._batch_size):
del_pks = del_host_pks[offset : (offset + self._batch_size)]
hosts_to_remove = list(db_hosts.filter(pk__in=del_pks))
if hosts_to_remove:
group_host_count += len(hosts_to_remove)
db_group.hosts.remove(*hosts_to_remove)
for db_host in hosts_to_remove:
logger.debug('Host "%s" removed from group "%s"', db_host.name, db_group.name)
for db_host in db_hosts.filter(pk__in=del_pks):
group_host_count += 1
if db_host not in db_group.hosts.all():
continue
db_group.hosts.remove(db_host)
logger.debug('Host "%s" removed from group "%s"', db_host.name, db_group.name)
if settings.SQL_DEBUG:
logger.warning(
'group-group and group-host deletions took %d queries for %d relationships',

View File

@@ -5,7 +5,6 @@ import logging
import uuid
from django.db import models
from django.conf import settings
from django.db.models import OuterRef, Subquery
from django.db.models.functions import Lower
from ansible_base.lib.utils.db import advisory_lock
@@ -24,65 +23,7 @@ class DeferJobCreatedManager(models.Manager):
return super(DeferJobCreatedManager, self).get_queryset().defer('job_created')
class HostLatestSummaryQuerySet(models.QuerySet):
"""Queryset that annotates and bulk-attaches the latest JobHostSummary
at queryset evaluation time, similar to prefetch_related().
Why not use Django's Prefetch?
Django's Prefetch with [:1] slicing fetches 1 record globally, not per-host
(Django ticket #26780). Window-function workarounds require Django 4.2+ and
are more complex. Prefetching all summaries then filtering in Python wastes
memory for hosts with many job runs. The approach here — annotate the latest
ID via Subquery, then in_bulk() only those IDs — is the same 2-query pattern
prefetch_related uses internally, customized for "latest per group."
Not streaming-safe: relies on _result_cache existing after _fetch_all().
"""
_awx_latest_summary_attached = False
def _clone(self):
clone = super()._clone()
clone._awx_latest_summary_attached = self._awx_latest_summary_attached
return clone
def with_latest_summary_id(self):
from awx.main.models.jobs import JobHostSummary
latest_summary = JobHostSummary.objects.filter(host_id=OuterRef('pk')).order_by('-id')
return self.annotate(
_latest_summary_id=Subquery(latest_summary.values('id')[:1]),
)
def _fetch_all(self):
super()._fetch_all()
if self._awx_latest_summary_attached or not self._result_cache:
return
# Only bulk-attach if the queryset was annotated via with_latest_summary_id().
# Without this guard, we'd set _latest_summary_cache=None on every host,
# masking the per-object fallback query in Host.latest_summary.
if not hasattr(self._result_cache[0], '_latest_summary_id'):
return
from awx.main.models.jobs import JobHostSummary
latest_summary_ids = [host._latest_summary_id for host in self._result_cache if host._latest_summary_id is not None]
if latest_summary_ids:
summaries_by_id = JobHostSummary.objects.select_related('job', 'job__job_template').in_bulk(latest_summary_ids)
else:
summaries_by_id = {}
for host in self._result_cache:
latest_summary_id = getattr(host, '_latest_summary_id', None)
host._latest_summary_cache = summaries_by_id.get(latest_summary_id)
self._awx_latest_summary_attached = True
class HostManager(models.Manager.from_queryset(HostLatestSummaryQuerySet)):
class HostManager(models.Manager):
"""Custom manager class for Hosts model."""
def active_count(self):
@@ -90,46 +31,38 @@ class HostManager(models.Manager.from_queryset(HostLatestSummaryQuerySet)):
Construction of query involves:
- remove any ordering specified in model's Meta
- Exclude hosts sourced from another Tower
- Exclude hosts in constructed inventories (these are shadow rows of source-inventory hosts)
- Restrict the query to only return the name column
- Only consider results that are unique
- Return the count of this query
"""
return (
self.order_by()
.exclude(inventory_sources__source='controller')
.exclude(inventory__kind='constructed')
.values(name_lower=Lower('name'))
.distinct()
.count()
)
return self.order_by().exclude(inventory_sources__source='controller').values(name_lower=Lower('name')).distinct().count()
def org_active_count(self, org_id):
"""Return count of active, unique hosts used by an organization.
Construction of query involves:
- remove any ordering specified in model's Meta
- Exclude hosts sourced from another Tower
- Exclude hosts in constructed inventories (these are shadow rows of source-inventory hosts)
- Consider only hosts where the canonical inventory is owned by the organization
- Restrict the query to only return the name column
- Only consider results that are unique
- Return the count of this query
"""
return (
self.order_by()
.exclude(inventory_sources__source='controller')
.exclude(inventory__kind='constructed')
.filter(inventory__organization=org_id)
.values('name')
.distinct()
.count()
)
return self.order_by().exclude(inventory_sources__source='controller').filter(inventory__organization=org_id).values('name').distinct().count()
def get_queryset(self):
"""When the parent instance of the host query set has a `kind=smart` and a `host_filter`
set. Use the `host_filter` to generate the queryset for the hosts.
"""
qs = super().get_queryset().defer('ansible_facts')
qs = (
super(HostManager, self)
.get_queryset()
.defer(
'last_job__extra_vars',
'last_job_host_summary__job__extra_vars',
'last_job__artifacts',
'last_job_host_summary__job__artifacts',
)
)
if hasattr(self, 'instance') and hasattr(self.instance, 'host_filter') and hasattr(self.instance, 'kind'):
if self.instance.kind == 'smart' and self.instance.host_filter is not None:

View File

@@ -1,29 +0,0 @@
# Generated by Django 5.2.8 on 2026-02-20 03:39
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('main', '0204_squashed_deletions'),
]
operations = [
migrations.AlterModelOptions(
name='instancegroup',
options={
'default_permissions': ('change', 'delete', 'view'),
'ordering': ('pk',),
'permissions': [('use_instancegroup', 'Can use instance group in a preference list of a resource')],
},
),
migrations.AlterModelOptions(
name='workflowjobnode',
options={'ordering': ('pk',)},
),
migrations.AlterModelOptions(
name='workflowjobtemplatenode',
options={'ordering': ('pk',)},
),
]

View File

@@ -28,7 +28,6 @@ from rest_framework.serializers import ValidationError as DRFValidationError
from ansible_base.lib.utils.db import advisory_lock
# AWX
from awx.main.constants import OIDC_CREDENTIAL_TYPE_NAMESPACES
from awx.api.versioning import reverse
from awx.main.fields import (
ImplicitRoleField,
@@ -49,6 +48,10 @@ from awx.main.models import Team, Organization
from awx.main.utils import encrypt_field
from awx_plugins.interfaces._temporary_private_licensing_api import detect_server_product_name
# DAB
from ansible_base.resource_registry.tasks.sync import get_resource_server_client
from ansible_base.resource_registry.utils.settings import resource_server_defined
__all__ = ['Credential', 'CredentialType', 'CredentialInputSource', 'build_safe_env']
logger = logging.getLogger('awx.main.models.credential')
@@ -76,6 +79,46 @@ def build_safe_env(env):
return safe_env
def check_resource_server_for_user_in_organization(user, organization, requesting_user):
if not resource_server_defined():
return False
if not requesting_user:
return False
client = get_resource_server_client(settings.RESOURCE_SERVICE_PATH, jwt_user_id=str(requesting_user.resource.ansible_id), raise_if_bad_request=False)
# need to get the organization object_id in resource server, by querying with ansible_id
response = client._make_request(path=f'resources/?ansible_id={str(organization.resource.ansible_id)}', method='GET')
response_json = response.json()
if response.status_code != 200:
logger.error(f'Failed to get organization object_id in resource server: {response_json.get("detail", "")}')
return False
if response_json.get('count', 0) == 0:
return False
org_id_in_resource_server = response_json['results'][0]['object_id']
client.base_url = client.base_url.replace('/api/gateway/v1/service-index/', '/api/gateway/v1/')
# find role assignments with:
# - roles Organization Member or Organization Admin
# - user ansible id
# - organization object id
response = client._make_request(
path=f'role_user_assignments/?role_definition__name__in=Organization Member,Organization Admin&user__resource__ansible_id={str(user.resource.ansible_id)}&object_id={org_id_in_resource_server}',
method='GET',
)
response_json = response.json()
if response.status_code != 200:
logger.error(f'Failed to get role user assignments in resource server: {response_json.get("detail", "")}')
return False
if response_json.get('count', 0) > 0:
return True
return False
class Credential(PasswordFieldsModel, CommonModelNameNotUnique, ResourceMixin):
"""
A credential contains information about how to talk to a remote resource
@@ -199,29 +242,6 @@ class Credential(PasswordFieldsModel, CommonModelNameNotUnique, ResourceMixin):
needed.append('vault_password')
return needed
@functools.cached_property
def context(self):
"""
Property for storing runtime context during credential resolution.
The context is a dict keyed by CredentialInputSource PK, where each value
is a dict of runtime fields for that input source. Example::
{
<input_source_pk>: {
"workload_identity_token": "<jwt_token>"
},
<another_input_source_pk>: {
"workload_identity_token": "<different_jwt_token>"
},
}
This structure allows each input source to have its own set of runtime
values, avoiding conflicts when a credential has multiple input sources
with different configurations (e.g., different JWT audiences).
"""
return {}
@cached_property
def dynamic_input_fields(self):
# if the credential is not yet saved we can't access the input_sources
@@ -347,20 +367,21 @@ class Credential(PasswordFieldsModel, CommonModelNameNotUnique, ResourceMixin):
def _get_dynamic_input(self, field_name):
for input_source in self.input_sources.all():
if input_source.input_field_name == field_name:
return input_source.get_input_value(context=self.context)
return input_source.get_input_value()
else:
raise ValueError('{} is not a dynamic input field'.format(field_name))
def validate_role_assignment(self, actor, role_definition, **kwargs):
requesting_user = kwargs.get('requesting_user', None)
if requesting_user and requesting_user.is_superuser:
return
if self.organization:
if isinstance(actor, User):
if actor.is_superuser:
return
if Organization.access_qs(actor, 'member').filter(id=self.organization.id).exists():
return
requesting_user = kwargs.get('requesting_user', None)
if check_resource_server_for_user_in_organization(actor, self.organization, requesting_user):
return
if isinstance(actor, Team):
if actor.organization == self.organization:
return
@@ -414,15 +435,13 @@ class CredentialType(CommonModelNameNotUnique):
def from_db(cls, db, field_names, values):
instance = super(CredentialType, cls).from_db(db, field_names, values)
if instance.managed and instance.namespace and instance.kind != "external":
native = ManagedCredentialType.registry.get(instance.namespace)
if native:
instance.inputs = native.inputs
instance.injectors = native.injectors
instance.custom_injectors = getattr(native, 'custom_injectors', None)
native = ManagedCredentialType.registry[instance.namespace]
instance.inputs = native.inputs
instance.injectors = native.injectors
instance.custom_injectors = getattr(native, 'custom_injectors', None)
elif instance.namespace and instance.kind == "external":
native = ManagedCredentialType.registry.get(instance.namespace)
if native:
instance.inputs = native.inputs
native = ManagedCredentialType.registry[instance.namespace]
instance.inputs = native.inputs
return instance
@@ -486,7 +505,6 @@ class CredentialType(CommonModelNameNotUnique):
existing = ct_class.objects.filter(name=default.name, kind=default.kind).first()
if existing is not None:
existing.namespace = default.namespace
existing.description = getattr(default, 'description', '')
existing.inputs = {}
existing.injectors = {}
existing.save()
@@ -526,14 +544,7 @@ class CredentialType(CommonModelNameNotUnique):
@classmethod
def load_plugin(cls, ns, plugin):
# TODO: User "side-loaded" credential custom_injectors isn't supported
ManagedCredentialType.registry[ns] = SimpleNamespace(
namespace=ns,
name=plugin.name,
kind='external',
inputs=plugin.inputs,
backend=plugin.backend,
description=getattr(plugin, 'plugin_description', ''),
)
ManagedCredentialType.registry[ns] = SimpleNamespace(namespace=ns, name=plugin.name, kind='external', inputs=plugin.inputs, backend=plugin.backend)
def inject_credential(self, credential, env, safe_env, args, private_data_dir, container_root=None):
from awx_plugins.interfaces._temporary_private_inject_api import inject_credential
@@ -545,13 +556,7 @@ class CredentialTypeHelper:
@classmethod
def get_creation_params(cls, cred_type):
if cred_type.kind == 'external':
return {
'namespace': cred_type.namespace,
'kind': cred_type.kind,
'name': cred_type.name,
'managed': True,
'description': getattr(cred_type, 'description', ''),
}
return dict(namespace=cred_type.namespace, kind=cred_type.kind, name=cred_type.name, managed=True)
return dict(
namespace=cred_type.namespace,
kind=cred_type.kind,
@@ -617,15 +622,7 @@ class CredentialInputSource(PrimordialModel):
raise ValidationError(_('Input field must be defined on target credential (options are {}).'.format(', '.join(sorted(defined_fields)))))
return self.input_field_name
def get_input_value(self, context: dict | None = None):
"""
Retrieve the value from the external credential backend.
Args:
context: Optional runtime context dict passed from the target credential.
"""
if context is None:
context = {}
def get_input_value(self):
backend = self.source_credential.credential_type.plugin.backend
backend_kwargs = {}
for field_name, value in self.source_credential.inputs.items():
@@ -636,17 +633,6 @@ class CredentialInputSource(PrimordialModel):
backend_kwargs.update(self.metadata)
# Resolve internal fields from the per-input-source context.
# The context dict is keyed by input source PK, e.g.:
# {42: {"workload_identity_token": "eyJ..."}, 43: {"workload_identity_token": "eyX..."}}
# This allows each input source to carry its own runtime values.
input_source_context = context.get(self.pk, {})
for field in self.source_credential.credential_type.inputs.get('fields', []):
if field.get('internal'):
value = input_source_context.get(field['id'])
if value is not None:
backend_kwargs[field['id']] = value
with set_environ(**settings.AWX_TASK_ENV):
return backend(**backend_kwargs)
@@ -655,20 +641,13 @@ class CredentialInputSource(PrimordialModel):
return reverse(view_name, kwargs={'pk': self.pk}, request=request)
def _is_oidc_namespace_disabled(ns):
"""Check if a credential namespace should be skipped based on the OIDC feature flag."""
return ns in OIDC_CREDENTIAL_TYPE_NAMESPACES and not getattr(settings, 'FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED', False)
def load_credentials():
awx_entry_points = {ep.name: ep for ep in entry_points(group='awx_plugins.managed_credentials')}
supported_entry_points = {ep.name: ep for ep in entry_points(group='awx_plugins.managed_credentials.supported')}
plugin_entry_points = awx_entry_points if detect_server_product_name() == 'AWX' else {**awx_entry_points, **supported_entry_points}
for ns, ep in plugin_entry_points.items():
if _is_oidc_namespace_disabled(ns):
continue
cred_plugin = ep.load()
if not hasattr(cred_plugin, 'inputs'):
setattr(cred_plugin, 'inputs', {})
@@ -687,8 +666,5 @@ def load_credentials():
credential_plugins = {}
for ns, ep in credential_plugins.items():
if _is_oidc_namespace_disabled(ns):
continue
plugin = ep.load()
CredentialType.load_plugin(ns, plugin)

View File

@@ -24,6 +24,7 @@ from awx.main.managers import DeferJobCreatedManager
from awx.main.constants import MINIMAL_EVENTS
from awx.main.models.base import CreatedModifiedModel
from awx.main.utils import ignore_inventory_computed_fields, camelcase_to_underscore
from awx.main.utils.db import bulk_update_sorted_by_id
analytics_logger = logging.getLogger('awx.analytics.job_events')
@@ -589,8 +590,20 @@ class JobEvent(BasePlaybookEvent):
JobHostSummary.objects.bulk_create(summaries.values())
# last_job and last_job_host_summary are now derived via
# JobHostSummary.latest_for_host / latest_job_for_host
# update the last_job_id and last_job_host_summary_id
# in single queries
host_mapping = dict((summary['host_id'], summary['id']) for summary in JobHostSummary.objects.filter(job_id=job.id).values('id', 'host_id'))
updated_hosts = set()
for h in all_hosts:
# if the hostname *shows up* in the playbook_on_stats event
if h.name in hostnames:
h.last_job_id = job.id
updated_hosts.add(h)
if h.id in host_mapping:
h.last_job_host_summary_id = host_mapping[h.id]
updated_hosts.add(h)
bulk_update_sorted_by_id(Host, updated_hosts, ['last_job_id', 'last_job_host_summary_id'])
# Create/update Host Metrics
self._update_host_metrics(updated_hosts_list)

View File

@@ -58,6 +58,8 @@ class ExecutionEnvironment(CommonModel):
return reverse('api:execution_environment_detail', kwargs={'pk': self.pk}, request=request)
def validate_role_assignment(self, actor, role_definition, **kwargs):
from awx.main.models.credential import check_resource_server_for_user_in_organization
if self.managed:
raise ValidationError({'object_id': _('Can not assign object roles to managed Execution Environments')})
if self.organization_id is None:
@@ -67,4 +69,8 @@ class ExecutionEnvironment(CommonModel):
if actor.has_obj_perm(self.organization, 'view'):
return
requesting_user = kwargs.get('requesting_user', None)
if check_resource_server_for_user_in_organization(actor, self.organization, requesting_user):
return
raise ValidationError({'user': _('User must have view permission to Execution Environment organization')})

View File

@@ -485,7 +485,6 @@ class InstanceGroup(HasPolicyEditsMixin, BaseModel, RelatedJobsMixin, ResourceMi
class Meta:
app_label = 'main'
ordering = ('pk',)
permissions = [('use_instancegroup', 'Can use instance group in a preference list of a resource')]
# Since this has no direct organization field only superuser can add, so remove add permission
default_permissions = ('change', 'delete', 'view')

View File

@@ -18,7 +18,7 @@ from django.db import transaction
from django.core.exceptions import ValidationError
from django.urls import resolve
from django.utils.timezone import now
from django.db.models import Q, Subquery, OuterRef
from django.db.models import Q
# REST Framework
from rest_framework.exceptions import ParseError
@@ -386,10 +386,7 @@ class Inventory(CommonModelNameNotUnique, ResourceMixin, RelatedJobsMixin, OpaQu
logger.debug("Going to update inventory computed fields, pk={0}".format(self.pk))
start_time = time.time()
active_hosts = self.hosts
from awx.main.models.jobs import JobHostSummary # circular import: inventory.py loads before jobs.py
latest_summary_failed = Subquery(JobHostSummary.objects.filter(host_id=OuterRef('pk')).order_by('-id').values('failed')[:1])
failed_hosts = active_hosts.annotate(_latest_failed=latest_summary_failed).filter(_latest_failed=True)
failed_hosts = active_hosts.filter(last_job_host_summary__failed=True)
active_groups = self.groups
if self.kind == 'smart':
active_groups = active_groups.none()
@@ -585,23 +582,6 @@ class Host(CommonModelNameNotUnique, RelatedJobsMixin):
objects = HostManager()
@property
def latest_summary(self):
if hasattr(self, '_latest_summary_cache'):
return self._latest_summary_cache
from awx.main.models.jobs import JobHostSummary
summary = JobHostSummary.objects.filter(host_id=self.pk).order_by('-id').select_related('job', 'job__job_template').first()
self._latest_summary_cache = summary
return summary
@property
def latest_job(self):
summary = self.latest_summary
if summary is None:
return None
return summary.job
def get_absolute_url(self, request=None):
return reverse('api:host_detail', kwargs={'pk': self.pk}, request=request)

View File

@@ -52,7 +52,7 @@ from awx.main.models.mixins import (
WebhookTemplateMixin,
OpaQueryPathMixin,
)
from awx.main.utils.common import get_job_variable_prefixes
from awx.main.constants import JOB_VARIABLE_PREFIXES
logger = logging.getLogger('awx.main.models.jobs')
@@ -817,20 +817,19 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana
def awx_meta_vars(self):
r = super(Job, self).awx_meta_vars()
prefixes = get_job_variable_prefixes()
if self.project:
for name in prefixes:
for name in JOB_VARIABLE_PREFIXES:
r['{}_project_revision'.format(name)] = self.project.scm_revision
r['{}_project_scm_branch'.format(name)] = self.project.scm_branch
if self.scm_branch:
for name in prefixes:
for name in JOB_VARIABLE_PREFIXES:
r['{}_job_scm_branch'.format(name)] = self.scm_branch
if self.job_template:
for name in prefixes:
for name in JOB_VARIABLE_PREFIXES:
r['{}_job_template_id'.format(name)] = self.job_template.pk
r['{}_job_template_name'.format(name)] = self.job_template.name
if self.execution_node:
for name in prefixes:
for name in JOB_VARIABLE_PREFIXES:
r['{}_execution_node'.format(name)] = self.execution_node
return r
@@ -846,21 +845,6 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana
def get_notification_friendly_name(self):
return "Job"
def get_source_hosts_for_constructed_inventory(self):
"""Return a QuerySet of the source (input inventory) hosts for a constructed inventory.
Constructed inventory hosts have an instance_id pointing to the real
host in the input inventory. This resolves those references and returns
a proper QuerySet (never a list), suitable for use with finish_fact_cache.
"""
Host = JobHostSummary._meta.get_field('host').related_model
if not self.inventory_id:
return Host.objects.none()
id_field = Host._meta.get_field('id')
return Host.objects.filter(id__in=self.inventory.hosts.exclude(instance_id='').values_list(Cast('instance_id', output_field=id_field))).only(
*HOST_FACTS_FIELDS
)
def get_hosts_for_fact_cache(self):
"""
Builds the queryset to use for writing or finalizing the fact cache
@@ -868,15 +852,17 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana
For constructed inventories, that means the original (input inventory) hosts
when slicing, that means only returning hosts in that slice
"""
Host = JobHostSummary._meta.get_field('host').related_model
if not self.inventory_id:
Host = JobHostSummary._meta.get_field('host').related_model
return Host.objects.none()
if self.inventory.kind == 'constructed':
host_qs = self.get_source_hosts_for_constructed_inventory()
id_field = Host._meta.get_field('id')
host_qs = Host.objects.filter(id__in=self.inventory.hosts.exclude(instance_id='').values_list(Cast('instance_id', output_field=id_field)))
else:
host_qs = self.inventory.hosts.only(*HOST_FACTS_FIELDS)
host_qs = self.inventory.hosts
host_qs = host_qs.only(*HOST_FACTS_FIELDS)
host_qs = self.inventory.get_sliced_hosts(host_qs, self.job_slice_number, self.job_slice_count)
return host_qs
@@ -1141,22 +1127,6 @@ class JobHostSummary(CreatedModifiedModel):
self.skipped,
)
@classmethod
def latest_for_host(cls, host_id):
"""Return the most recent JobHostSummary for a given host, or None."""
return cls.objects.filter(host_id=host_id).order_by('-id').first()
@classmethod
def latest_job_for_host(cls, host_id):
"""Return the Job from the most recent JobHostSummary for a host, or None."""
summary = cls.latest_for_host(host_id)
if summary:
try:
return summary.job
except cls.job.field.related_model.DoesNotExist:
return None
return None
def get_absolute_url(self, request=None):
return reverse('api:job_host_summary_detail', kwargs={'pk': self.pk}, request=request)

View File

@@ -72,10 +72,10 @@ def _fast_forward_rrule(rrule, ref_dt=None):
if ref_dt is None:
ref_dt = now()
dtstart_tz = rrule._dtstart.tzinfo
ref_dt = ref_dt.astimezone(dtstart_tz)
ref_dt = ref_dt.astimezone(datetime.timezone.utc)
if rrule._dtstart > ref_dt:
rrule_dtstart_utc = rrule._dtstart.astimezone(datetime.timezone.utc)
if rrule_dtstart_utc > ref_dt:
return rrule
interval = rrule._interval if rrule._interval else 1
@@ -84,14 +84,20 @@ def _fast_forward_rrule(rrule, ref_dt=None):
elif rrule._freq == dateutil.rrule.MINUTELY:
interval *= 60
# if after converting to seconds the interval is still a fraction,
# just return original rrule
if isinstance(interval, float) and not interval.is_integer():
return rrule
seconds_since_dtstart = (ref_dt - rrule._dtstart).total_seconds()
seconds_since_dtstart = (ref_dt - rrule_dtstart_utc).total_seconds()
# it is important to fast forward by a number that is divisible by
# interval. For example, if interval is 7 hours, we fast forward by 7, 14, 21, etc. hours.
# Otherwise, the occurrences after the fast forward might not match the ones before.
# x // y is integer division, lopping off any remainder, so that we get the outcome we want.
interval_aligned_offset = datetime.timedelta(seconds=(seconds_since_dtstart // interval) * interval)
new_start = rrule._dtstart + interval_aligned_offset
new_rrule = rrule.replace(dtstart=new_start)
new_start = rrule_dtstart_utc + interval_aligned_offset
new_rrule = rrule.replace(dtstart=new_start.astimezone(rrule._dtstart.tzinfo))
return new_rrule

View File

@@ -58,8 +58,7 @@ from awx.main.utils.common import (
)
from awx.main.utils.encryption import encrypt_dict, decrypt_field
from awx.main.utils import polymorphic
from awx.main.constants import ACTIVE_STATES, CAN_CANCEL
from awx.main.utils.common import get_job_variable_prefixes
from awx.main.constants import ACTIVE_STATES, CAN_CANCEL, JOB_VARIABLE_PREFIXES
from awx.main.redact import UriCleaner, REPLACE_STR
from awx.main.consumers import emit_channel_notification
from awx.main.fields import AskForField, OrderedManyToManyField
@@ -1569,8 +1568,7 @@ class UnifiedJob(
by AWX, for purposes of client playbook hooks
"""
r = {}
prefixes = get_job_variable_prefixes()
for name in prefixes:
for name in JOB_VARIABLE_PREFIXES:
r['{}_job_id'.format(name)] = self.pk
r['{}_job_launch_type'.format(name)] = self.launch_type
@@ -1579,7 +1577,7 @@ class UnifiedJob(
wj = self.get_workflow_job()
if wj:
schedule = getattr_dne(wj, 'schedule')
for name in prefixes:
for name in JOB_VARIABLE_PREFIXES:
r['{}_workflow_job_id'.format(name)] = wj.pk
r['{}_workflow_job_name'.format(name)] = wj.name
r['{}_workflow_job_launch_type'.format(name)] = wj.launch_type
@@ -1590,12 +1588,12 @@ class UnifiedJob(
if not created_by:
schedule = getattr_dne(self, 'schedule')
if schedule:
for name in prefixes:
for name in JOB_VARIABLE_PREFIXES:
r['{}_schedule_id'.format(name)] = schedule.pk
r['{}_schedule_name'.format(name)] = schedule.name
if created_by:
for name in prefixes:
for name in JOB_VARIABLE_PREFIXES:
r['{}_user_id'.format(name)] = created_by.pk
r['{}_user_name'.format(name)] = created_by.username
r['{}_user_email'.format(name)] = created_by.email
@@ -1604,7 +1602,7 @@ class UnifiedJob(
inventory = getattr_dne(self, 'inventory')
if inventory:
for name in prefixes:
for name in JOB_VARIABLE_PREFIXES:
r['{}_inventory_id'.format(name)] = inventory.pk
r['{}_inventory_name'.format(name)] = inventory.name

View File

@@ -200,7 +200,6 @@ class WorkflowJobTemplateNode(WorkflowNodeBase):
indexes = [
models.Index(fields=['identifier']),
]
ordering = ('pk',)
def get_absolute_url(self, request=None):
return reverse('api:workflow_job_template_node_detail', kwargs={'pk': self.pk}, request=request)
@@ -287,7 +286,6 @@ class WorkflowJobNode(WorkflowNodeBase):
models.Index(fields=["identifier", "workflow_job"]),
models.Index(fields=['identifier']),
]
ordering = ('pk',)
@property
def event_processing_finished(self):
@@ -345,11 +343,7 @@ class WorkflowJobNode(WorkflowNodeBase):
)
data.update(accepted_fields) # missing fields are handled in the scheduler
# build ancestor artifacts, save them to node model for later
# initialize from pre-seeded ancestor_artifacts (set on root nodes of
# child workflows via seed_root_ancestor_artifacts to carry artifacts
# from the parent workflow); exclude job_slice which is internal
# metadata handled separately below
aa_dict = {k: v for k, v in self.ancestor_artifacts.items() if k != 'job_slice'} if self.ancestor_artifacts else {}
aa_dict = {}
is_root_node = True
for parent_node in self.get_parent_nodes():
is_root_node = False
@@ -370,13 +364,11 @@ class WorkflowJobNode(WorkflowNodeBase):
data['survey_passwords'] = password_dict
# process extra_vars
extra_vars = data.get('extra_vars', {})
if ujt_obj and isinstance(ujt_obj, JobTemplate):
if ujt_obj and isinstance(ujt_obj, (JobTemplate, WorkflowJobTemplate)):
if aa_dict:
functional_aa_dict = copy(aa_dict)
functional_aa_dict.pop('_ansible_no_log', None)
extra_vars.update(functional_aa_dict)
elif ujt_obj and isinstance(ujt_obj, WorkflowJobTemplate):
pass # artifacts are applied via seed_root_ancestor_artifacts in the task manager
# Workflow Job extra_vars higher precedence than ancestor artifacts
extra_vars.update(wj_special_vars)
@@ -740,18 +732,6 @@ class WorkflowJob(UnifiedJob, WorkflowJobOptions, SurveyJobMixin, JobNotificatio
wj = wj.get_workflow_job()
return ancestors
def seed_root_ancestor_artifacts(self, artifacts):
"""Apply parent workflow artifacts to root nodes so they propagate
through the normal ancestor_artifacts channel instead of being
baked into this workflow's extra_vars."""
self.workflow_job_nodes.exclude(
workflowjobnodes_success__isnull=False,
).exclude(
workflowjobnodes_failure__isnull=False,
).exclude(
workflowjobnodes_always__isnull=False,
).update(ancestor_artifacts=artifacts)
def get_effective_artifacts(self, **kwargs):
"""
For downstream jobs of a workflow nested inside of a workflow,
@@ -936,17 +916,6 @@ class WorkflowApproval(UnifiedJob, JobNotificationMixin):
ScheduleWorkflowManager().schedule()
return reverse('api:workflow_approval_deny', kwargs={'pk': self.pk}, request=request)
def cancel(self, job_explanation=None, is_chain=False):
# WorkflowApprovals have no dispatcher process (they wait for human
# input) and are excluded from TaskManager processing, so the base
# cancel() would only set cancel_flag without ever transitioning the
# status. We call super() for the flag, then transition directly.
has_already_canceled = bool(self.status == 'canceled')
super().cancel(job_explanation=job_explanation, is_chain=is_chain)
if self.status != 'canceled' and not has_already_canceled:
self.status = 'canceled'
self.save(update_fields=['status'])
def signal_start(self, **kwargs):
can_start = super(WorkflowApproval, self).signal_start(**kwargs)
self.started = self.created

View File

@@ -19,8 +19,13 @@ class ActivityStreamRegistrar(object):
pre_delete.connect(activity_stream_delete, sender=model, dispatch_uid=str(self.__class__) + str(model) + "_delete")
for m2mfield in model._meta.many_to_many:
m2m_attr = getattr(model, m2mfield.name)
m2m_changed.connect(activity_stream_associate, sender=m2m_attr.through, dispatch_uid=str(self.__class__) + str(m2m_attr.through) + "_associate")
try:
m2m_attr = getattr(model, m2mfield.name)
m2m_changed.connect(
activity_stream_associate, sender=m2m_attr.through, dispatch_uid=str(self.__class__) + str(m2m_attr.through) + "_associate"
)
except AttributeError:
pass
def disconnect(self, model):
if model in self.models:

View File

@@ -48,6 +48,11 @@ class SimpleDAG(object):
'''
self.node_to_edges_by_label = dict()
def __contains__(self, obj):
if self.node['node_object'] in self.node_obj_to_node_index:
return True
return False
def __len__(self):
return len(self.nodes)

View File

@@ -122,11 +122,8 @@ class WorkflowDAG(SimpleDAG):
if not job:
continue
elif job.can_cancel:
cancel_finished = False
job.cancel()
# If the job is not yet in a terminal state after .cancel(),
# the TaskManager still needs to process it.
if job.status not in ('successful', 'failed', 'canceled', 'error'):
cancel_finished = False
return cancel_finished
def is_workflow_done(self):

View File

@@ -196,10 +196,6 @@ class WorkflowManager(TaskBase):
workflow_job.start_args = '' # blank field to remove encrypted passwords
workflow_job.save(update_fields=['status', 'start_args'])
status_changed = True
else:
# Speed-up: schedule the task manager so it can process the
# canceled pending jobs without waiting for the next cycle.
ScheduleTaskManager().schedule()
else:
dnr_nodes = dag.mark_dnr_nodes()
WorkflowJobNode.objects.bulk_update(dnr_nodes, ['do_not_run'])
@@ -241,8 +237,6 @@ class WorkflowManager(TaskBase):
job = spawn_node.unified_job_template.create_unified_job(**kv)
spawn_node.job = job
spawn_node.save()
if spawn_node.ancestor_artifacts and isinstance(spawn_node.unified_job_template, WorkflowJobTemplate):
job.seed_root_ancestor_artifacts(spawn_node.ancestor_artifacts)
logger.debug('Spawned %s in %s for node %s', job.log_format, workflow_job.log_format, spawn_node.pk)
can_start = True
if isinstance(spawn_node.unified_job_template, WorkflowJobTemplate):
@@ -449,29 +443,17 @@ class TaskManager(TaskBase):
self.controlplane_ig = self.tm_models.instance_groups.controlplane_ig
def process_job_dep_failures(self, task):
"""If job depends on a job that has failed or been canceled, mark as failed.
Returns True if a dep failure was found, False otherwise.
"""
"""If job depends on a job that has failed, mark as failed and handle misc stuff."""
for dep in task.dependent_jobs.all():
# if we detect a failed, error, or canceled dependency, go ahead and fail this task.
if dep.status in ("error", "failed", "canceled"):
# if we detect a failed or error dependency, go ahead and fail this task.
if dep.status in ("error", "failed"):
task.status = 'failed'
if dep.status == 'canceled':
logger.warning(f'Previous task canceled, failing task: {task.id} dep: {dep.id} task manager')
task.job_explanation = 'Previous Task Canceled: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (
get_type_for_model(type(dep)),
dep.name,
dep.id,
)
ScheduleWorkflowManager().schedule() # speedup for dependency chains in workflow, on workflow cancel
else:
logger.warning(f'Previous task failed, failing task: {task.id} dep: {dep.id} task manager')
task.job_explanation = 'Previous Task Failed: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (
get_type_for_model(type(dep)),
dep.name,
dep.id,
)
logger.warning(f'Previous task failed task: {task.id} dep: {dep.id} task manager')
task.job_explanation = 'Previous Task Failed: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (
get_type_for_model(type(dep)),
dep.name,
dep.id,
)
task.save(update_fields=['status', 'job_explanation'])
task.websocket_emit_status('failed')
self.pre_start_failed.append(task.id)
@@ -563,17 +545,8 @@ class TaskManager(TaskBase):
logger.warning("Task manager has reached time out while processing pending jobs, exiting loop early")
break
if task.cancel_flag:
logger.debug(f"Canceling pending task {task.log_format} because cancel_flag is set")
task.status = 'canceled'
task.job_explanation = gettext_noop("This job was canceled before it started.")
task.save(update_fields=['status', 'job_explanation'])
task.websocket_emit_status('canceled')
self.pre_start_failed.append(task.id)
ScheduleWorkflowManager().schedule()
continue
if self.process_job_dep_failures(task):
has_failed = self.process_job_dep_failures(task)
if has_failed:
continue
blocked_by = self.job_blocked_by(task)

View File

@@ -36,6 +36,7 @@ from awx.main.models import (
Inventory,
InventorySource,
Job,
JobHostSummary,
Organization,
Project,
Role,
@@ -250,9 +251,45 @@ def migrate_children_from_deleted_group_to_parent_groups(sender, **kwargs):
pass
# Host.last_job and Host.last_job_host_summary are now derived from
# JobHostSummary.latest_for_host / latest_job_for_host.
# No signal handlers needed to maintain these denormalized FKs.
# Update host pointers to last_job and last_job_host_summary when a job is deleted
def _update_host_last_jhs(host):
jhs_qs = JobHostSummary.objects.filter(host__pk=host.pk)
try:
jhs = jhs_qs.order_by('-job__pk')[0]
except IndexError:
jhs = None
update_fields = []
try:
last_job = jhs.job if jhs else None
except Job.DoesNotExist:
# The job (and its summaries) have already been/are currently being
# deleted, so there's no need to update the host w/ a reference to it
return
if host.last_job != last_job:
host.last_job = last_job
update_fields.append('last_job')
if host.last_job_host_summary != jhs:
host.last_job_host_summary = jhs
update_fields.append('last_job_host_summary')
if update_fields:
host.save(update_fields=update_fields)
@receiver(pre_delete, sender=Job)
def save_host_pks_before_job_delete(sender, **kwargs):
instance = kwargs['instance']
hosts_qs = Host.objects.filter(last_job__pk=instance.pk)
instance._saved_hosts_pks = set(hosts_qs.values_list('pk', flat=True))
@receiver(post_delete, sender=Job)
def update_host_last_job_after_job_deleted(sender, **kwargs):
instance = kwargs['instance']
hosts_pks = getattr(instance, '_saved_hosts_pks', [])
for host in Host.objects.filter(pk__in=hosts_pks):
_update_host_last_jhs(host)
# Set via ActivityStreamRegistrar to record activity stream events

View File

@@ -54,6 +54,9 @@ def try_load_query_file(artifact_dir) -> Tuple[bool, Optional[dict]]:
returns the contents of ansible_data.json if present
"""
if not flag_enabled("FEATURE_INDIRECT_NODE_COUNTING_ENABLED"):
return False, None
queries_path = os.path.join(artifact_dir, COLLECTION_FILENAME)
if not os.path.isfile(queries_path):
logger.info(f"no query file found: {queries_path}")
@@ -274,6 +277,20 @@ class RunnerCallback:
def artifacts_handler(self, artifact_dir):
success, query_file_contents = try_load_query_file(artifact_dir)
if success:
self.delay_update(event_queries_processed=False)
collections_info = collect_queries(query_file_contents)
for collection, data in collections_info.items():
version = data['version']
event_query = data['host_query']
instance = EventQuery(fqcn=collection, collection_version=version, event_query=event_query)
try:
instance.validate_unique()
instance.save()
logger.info(f"eventy query for collection {collection}, version {version} created")
except ValidationError as e:
logger.info(e)
if 'installed_collections' in query_file_contents:
self.delay_update(installed_collections=query_file_contents['installed_collections'])
else:
@@ -284,21 +301,6 @@ class RunnerCallback:
else:
logger.warning(f'The file {COLLECTION_FILENAME} unexpectedly did not contain ansible_version')
if flag_enabled("FEATURE_INDIRECT_NODE_COUNTING_ENABLED"):
self.delay_update(event_queries_processed=False)
collections_info = collect_queries(query_file_contents)
for collection, data in collections_info.items():
version = data['version']
event_query = data['host_query']
instance = EventQuery(fqcn=collection, collection_version=version, event_query=event_query)
try:
instance.validate_unique()
instance.save()
logger.info(f"event query for collection {collection}, version {version} created")
except ValidationError as e:
logger.info(e)
self.artifacts_processed = True

View File

@@ -25,8 +25,7 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
log_data = log_data or {}
log_data['inventory_id'] = inventory_id
log_data['written_ct'] = 0
# Dict mapping host name -> bool (True if a fact file was written)
hosts_cached = {}
hosts_cached = []
# Create the fact_cache directory inside artifacts_dir
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
@@ -38,14 +37,13 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
last_write_time = None
for host in hosts:
hosts_cached.append(host.name)
if not host.ansible_facts_modified or (timeout and host.ansible_facts_modified < now() - datetime.timedelta(seconds=timeout)):
hosts_cached[host.name] = False
continue # facts are expired - do not write them
filepath = os.path.join(fact_cache_dir, host.name)
if not os.path.realpath(filepath).startswith(fact_cache_dir):
logger.error(f'facts for host {smart_str(host.name)} could not be cached')
hosts_cached[host.name] = False
continue
try:
@@ -53,18 +51,9 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
os.chmod(f.name, 0o600)
json.dump(host.ansible_facts, f)
log_data['written_ct'] += 1
# Backdate the file by 2 seconds so finish_fact_cache can reliably
# distinguish these reference files from files updated by ansible.
# This guarantees fact file mtime < summary file mtime even with
# zipfile's 2-second timestamp rounding during artifact transfer.
mtime = os.path.getmtime(filepath)
backdated = mtime - 2
os.utime(filepath, (backdated, backdated))
last_write_time = backdated
hosts_cached[host.name] = True
last_write_time = os.path.getmtime(filepath)
except IOError:
logger.error(f'facts for host {smart_str(host.name)} could not be cached')
hosts_cached[host.name] = False
continue
# Write summary file directly to the artifacts_dir
@@ -73,6 +62,7 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
summary_data = {
'last_write_time': last_write_time,
'hosts_cached': hosts_cached,
'written_ct': log_data['written_ct'],
}
with open(summary_file, 'w', encoding='utf-8') as f:
json.dump(summary_data, f, indent=2)
@@ -84,7 +74,7 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
msg='Inventory {inventory_id} host facts: updated {updated_ct}, cleared {cleared_ct}, unchanged {unmodified_ct}, took {delta:.3f} s',
add_log_data=True,
)
def finish_fact_cache(host_qs, artifacts_dir, job_id=None, inventory_id=None, job_created=None, log_data=None):
def finish_fact_cache(artifacts_dir, job_id=None, inventory_id=None, log_data=None):
log_data = log_data or {}
log_data['inventory_id'] = inventory_id
log_data['updated_ct'] = 0
@@ -99,118 +89,63 @@ def finish_fact_cache(host_qs, artifacts_dir, job_id=None, inventory_id=None, jo
try:
with open(summary_path, 'r', encoding='utf-8') as f:
summary = json.load(f)
facts_write_time = os.path.getmtime(summary_path)
facts_write_time = os.path.getmtime(summary_path) # After successful read
except (json.JSONDecodeError, OSError) as e:
logger.error(f'Error reading summary file at {summary_path}: {e}')
return
hosts_cached_map = summary.get('hosts_cached', {})
host_names = summary.get('hosts_cached', [])
hosts_cached = Host.objects.filter(name__in=host_names).order_by('id').iterator()
# Path where individual fact files were written
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
hosts_to_update = []
# Phase 1: Scan files on disk to discover which hosts have updated or missing facts
hosts_with_updates = set() # hostnames whose fact file was modified by Ansible
hosts_to_clear = [] # hostnames where Ansible removed the fact file
seen_in_dir = set() # hostnames we found as files on disk
for host in hosts_cached:
filepath = os.path.join(fact_cache_dir, host.name)
if not os.path.realpath(filepath).startswith(fact_cache_dir):
logger.error(f'Invalid path for facts file: {filepath}')
continue
if os.path.isdir(fact_cache_dir):
for filename in os.listdir(fact_cache_dir):
if filename not in hosts_cached_map:
continue # not an expected host for this job
if os.path.exists(filepath):
# If the file changed since we wrote the last facts file, pre-playbook run...
modified = os.path.getmtime(filepath)
if not facts_write_time or modified >= facts_write_time:
try:
with codecs.open(filepath, 'r', encoding='utf-8') as f:
ansible_facts = json.load(f)
except ValueError:
continue
filepath = os.path.join(fact_cache_dir, filename)
if os.path.islink(filepath):
logger.error(f'Invalid path for facts file: {filepath}')
continue
if not os.path.isfile(filepath):
continue
seen_in_dir.add(filename)
try:
modified = os.path.getmtime(filepath)
except OSError as e:
logger.warning(f'Could not stat facts file {filepath}: {e}')
continue
if modified >= facts_write_time:
hosts_with_updates.add(filename)
if ansible_facts != host.ansible_facts:
host.ansible_facts = ansible_facts
host.ansible_facts_modified = now()
hosts_to_update.append(host)
logger.info(
f'New fact for inventory {smart_str(host.inventory.name)} host {smart_str(host.name)}',
extra=dict(
inventory_id=host.inventory.id,
host_name=host.name,
ansible_facts=host.ansible_facts,
ansible_facts_modified=host.ansible_facts_modified.isoformat(),
job_id=job_id,
),
)
log_data['updated_ct'] += 1
else:
log_data['unmodified_ct'] += 1
else:
log_data['unmodified_ct'] += 1
# Check for files we wrote pre-job that are now missing (Ansible cleared facts)
for hostname, was_written in hosts_cached_map.items():
if hostname in seen_in_dir:
continue # already handled above
if was_written:
hosts_to_clear.append(hostname)
else:
log_data['unmodified_ct'] += 1
# if the file goes missing, ansible removed it (likely via clear_facts)
# if the file goes missing, but the host has not started facts, then we should not clear the facts
host.ansible_facts = {}
host.ansible_facts_modified = now()
hosts_to_update.append(host)
logger.info(f'Facts cleared for inventory {smart_str(host.inventory.name)} host {smart_str(host.name)}')
log_data['cleared_ct'] += 1
# Phase 2: Stream updated facts to database in batches
if hosts_with_updates:
hosts_to_save = []
total_rows_updated = 0
for host in host_qs.filter(name__in=list(hosts_with_updates)).select_related('inventory').iterator():
filepath = os.path.join(fact_cache_dir, host.name)
try:
with codecs.open(filepath, 'r', encoding='utf-8') as f:
new_facts = json.load(f)
except (ValueError, OSError):
continue
if len(hosts_to_update) >= 100:
bulk_update_sorted_by_id(Host, hosts_to_update, fields=['ansible_facts', 'ansible_facts_modified'])
hosts_to_update = []
if new_facts != host.ansible_facts:
host.ansible_facts = new_facts
host.ansible_facts_modified = now()
hosts_to_save.append(host)
logger.info(
f'New fact for inventory {smart_str(host.inventory.name)} host {smart_str(host.name)}',
extra=dict(
inventory_id=host.inventory.id,
host_name=host.name,
ansible_facts=host.ansible_facts,
ansible_facts_modified=host.ansible_facts_modified.isoformat(),
job_id=job_id,
),
)
log_data['updated_ct'] += 1
else:
log_data['unmodified_ct'] += 1
if len(hosts_to_save) >= 100:
total_rows_updated += bulk_update_sorted_by_id(Host, hosts_to_save, fields=['ansible_facts', 'ansible_facts_modified'])
hosts_to_save = []
if hosts_to_save:
total_rows_updated += bulk_update_sorted_by_id(Host, hosts_to_save, fields=['ansible_facts', 'ansible_facts_modified'])
# Mismatch means a concurrent process changed or deleted hosts between our read and bulk update
if total_rows_updated != log_data['updated_ct']:
logger.warning(
f'Fact update for inventory {inventory_id} job {job_id}: expected to update {log_data["updated_ct"]} hosts but {total_rows_updated} rows were changed'
)
# Phase 3: Clear facts for hosts whose files were removed by Ansible
if hosts_to_clear:
hosts = list(host_qs.filter(name__in=hosts_to_clear).select_related('inventory'))
clear_hosts = []
for host in hosts:
if job_created and host.ansible_facts_modified and host.ansible_facts_modified > job_created:
logger.warning(
f'Skipping fact clear for host {smart_str(host.name)} in job {job_id} '
f'inventory {inventory_id}: host ansible_facts_modified '
f'({host.ansible_facts_modified.isoformat()}) is after this job\'s '
f'created time ({job_created.isoformat()}). '
f'A concurrent job likely updated this host\'s facts while this job was running.'
)
log_data['unmodified_ct'] += 1
else:
host.ansible_facts = {}
host.ansible_facts_modified = now()
clear_hosts.append(host)
logger.info(f'Facts cleared for inventory {smart_str(host.inventory.name)} host {smart_str(host.name)}')
log_data['cleared_ct'] += 1
if clear_hosts:
rows = bulk_update_sorted_by_id(Host, clear_hosts, fields=['ansible_facts', 'ansible_facts_modified'])
if rows != len(clear_hosts):
logger.warning(f'Fact clear for inventory {inventory_id} job {job_id}: expected to clear {len(clear_hosts)} hosts but {rows} rows were changed')
logger.debug(f'Updated {log_data["updated_ct"]} host facts for inventory {inventory_id} in job {job_id}')
bulk_update_sorted_by_id(Host, hosts_to_update, fields=['ansible_facts', 'ansible_facts_modified'])

View File

@@ -17,6 +17,7 @@ import urllib.parse as urlparse
# Django
from django.conf import settings
from django.db import transaction
# Shared code for the AWX platform
from awx_plugins.interfaces._temporary_private_container_api import CONTAINER_ROOT, get_incontainer_path
@@ -94,7 +95,6 @@ from flags.state import flag_enabled
# Workload Identity
from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope
from awx.main.utils.workload_identity import retrieve_workload_identity_jwt_with_claims
logger = logging.getLogger('awx.main.tasks.jobs')
@@ -104,6 +104,11 @@ def populate_claims_for_workload(unified_job) -> dict:
Extract JWT claims from a Controller workload for the aap_controller_automation_job scope.
"""
# Related objects in the UnifiedJob model, applies to all job types
organization = getattr_dne(unified_job, 'organization')
ujt = getattr_dne(unified_job, 'unified_job_template')
instance_group = getattr_dne(unified_job, 'instance_group')
claims = {
AutomationControllerJobScope.CLAIM_JOB_ID: unified_job.id,
AutomationControllerJobScope.CLAIM_JOB_NAME: unified_job.name,
@@ -158,24 +163,6 @@ def populate_claims_for_workload(unified_job) -> dict:
return claims
def retrieve_workload_identity_jwt(
unified_job: UnifiedJob,
audience: str,
scope: str,
workload_ttl_seconds: int | None = None,
) -> str:
"""Retrieve JWT token from workload claims.
Raises:
RuntimeError: if the workload identity client is not configured.
"""
return retrieve_workload_identity_jwt_with_claims(
populate_claims_for_workload(unified_job),
audience,
scope,
workload_ttl_seconds,
)
def with_path_cleanup(f):
@functools.wraps(f)
def _wrapped(self, *args, **kwargs):
@@ -202,7 +189,6 @@ def dispatch_waiting_jobs(binder):
if not kwargs:
kwargs = {}
binder.control('run', data={'task': serialize_task(uj._get_task_class()), 'args': [uj.id], 'kwargs': kwargs, 'uuid': uj.celery_task_id})
UnifiedJob.objects.filter(pk=uj.pk, status='waiting').update(status='running', start_args='')
class BaseTask(object):
@@ -217,63 +203,6 @@ class BaseTask(object):
self.update_attempts = int(getattr(settings, 'DISPATCHER_DB_DOWNTOWN_TOLLERANCE', settings.DISPATCHER_DB_DOWNTIME_TOLERANCE) / 5)
self.runner_callback = self.callback_class(model=self.model)
@functools.cached_property
def _credentials(self):
"""
Credentials for the task execution.
Fetches credentials once using build_credentials_list() and stores
them for the duration of the task to avoid redundant database queries.
"""
credentials_list = self.build_credentials_list(self.instance)
# Convert to list to prevent re-evaluation of QuerySet
return list(credentials_list)
def populate_workload_identity_tokens(self, additional_credentials=None):
"""
Populate credentials with workload identity tokens.
Sets the context on Credential objects that have input sources
using compatible external credential types.
"""
credentials = list(self._credentials)
if additional_credentials:
credentials.extend(additional_credentials)
credential_input_sources = (
(credential.context, src)
for credential in credentials
for src in credential.input_sources.all()
if any(
field.get('id') == 'workload_identity_token' and field.get('internal')
for field in src.source_credential.credential_type.inputs.get('fields', [])
)
)
for credential_ctx, input_src in credential_input_sources:
if flag_enabled("FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED"):
effective_timeout = self.get_instance_timeout(self.instance)
workload_ttl = effective_timeout if effective_timeout else None
try:
jwt = retrieve_workload_identity_jwt(
self.instance,
audience=input_src.source_credential.get_input('url'),
scope=AutomationControllerJobScope.name,
workload_ttl_seconds=workload_ttl,
)
# Store token keyed by input source PK, since a credential can have
# multiple input sources (one per field), each potentially with a different audience
credential_ctx[input_src.pk] = {"workload_identity_token": jwt}
except Exception as e:
self.instance.job_explanation = (
f'Could not generate workload identity token for credential {input_src.source_credential.name} used in this job. Error:\n{e}'
)
self.instance.status = 'error'
self.instance.save()
else:
self.instance.job_explanation = (
f'Flag FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED is not enabled, required for credential {input_src.source_credential.name} used in this job.'
)
self.instance.status = 'error'
self.instance.save()
def update_model(self, pk, _attempt=0, **updates):
return update_model(self.model, pk, _attempt=0, _max_attempts=self.update_attempts, **updates)
@@ -425,19 +354,6 @@ class BaseTask(object):
private_data_files['credentials'][credential] = self.write_private_data_file(private_data_dir, None, data, sub_dir='env')
for credential, data in private_data.get('certificates', {}).items():
self.write_private_data_file(private_data_dir, 'ssh_key_data-cert.pub', data, sub_dir=os.path.join('artifacts', str(self.instance.id)))
# Copy vendor collections to private_data_dir for indirect node counting
# This makes external query files available to the callback plugin in EEs
if flag_enabled("FEATURE_INDIRECT_NODE_COUNTING_ENABLED"):
vendor_src = '/var/lib/awx/vendor_collections'
vendor_dest = os.path.join(private_data_dir, 'vendor_collections')
if os.path.exists(vendor_src):
try:
shutil.copytree(vendor_src, vendor_dest)
logger.debug(f"Copied vendor collections from {vendor_src} to {vendor_dest}")
except Exception as e:
logger.warning(f"Failed to copy vendor collections: {e}")
return private_data_files, ssh_key_data
def build_passwords(self, instance, runtime_passwords):
@@ -511,7 +427,6 @@ class BaseTask(object):
return []
def get_instance_timeout(self, instance):
"""Return the effective job timeout in seconds."""
global_timeout_setting_name = instance._global_timeout_setting()
if global_timeout_setting_name:
global_timeout = getattr(settings, global_timeout_setting_name, 0)
@@ -620,32 +535,48 @@ class BaseTask(object):
def should_use_fact_cache(self):
return False
def transition_status(self, pk: int) -> bool:
"""Atomically transition status to running, if False returned, another process got it"""
with transaction.atomic():
# Explanation of parts for the fetch:
# .values - avoid loading a full object, this is known to lead to deadlocks due to signals
# the signals load other related rows which another process may be locking, and happens in practice
# of=('self',) - keeps FK tables out of the lock list, another way deadlocks can happen
# .get - just load the single job
instance_data = UnifiedJob.objects.select_for_update(of=('self',)).values('status', 'cancel_flag').get(pk=pk)
# If status is not waiting (obtained under lock) then this process does not have clearence to run
if instance_data['status'] == 'waiting':
if instance_data['cancel_flag']:
updated_status = 'canceled'
else:
updated_status = 'running'
# Explanation of the update:
# .filter - again, do not load the full object
# .update - a bulk update on just that one row, avoid loading unintended data
UnifiedJob.objects.filter(pk=pk).update(status=updated_status, start_args='')
elif instance_data['status'] == 'running':
logger.info(f'Job {pk} is being ran by another process, exiting')
return False
return True
@with_path_cleanup
@with_signal_handling
def run(self, pk, **kwargs):
"""
Run the job/task and capture its output.
"""
if not self.instance: # Used to skip fetch for local runs
# Load the instance
self.instance = self.update_model(pk)
# status should be "running" from dispatch_waiting_jobs,
# but may still be "waiting" if the worker picked this up before the status update landed.
if self.instance.status == 'waiting':
UnifiedJob.objects.filter(pk=pk).update(status="running", start_args='')
self.instance.refresh_from_db()
if not self.transition_status(pk):
logger.info(f'Job {pk} is being ran by another process, exiting')
return
# Load the instance
self.instance = self.update_model(pk)
if self.instance.status != 'running':
logger.error(f'Not starting {self.instance.status} task pk={pk} because its status "{self.instance.status}" is not expected')
return
if self.instance.cancel_flag:
self.instance = self.update_model(pk, status='canceled')
self.instance.websocket_emit_status('canceled')
return
self.instance.websocket_emit_status("running")
status, rc = 'error', None
self.runner_callback.event_ct = 0
@@ -684,12 +615,6 @@ class BaseTask(object):
if not os.path.exists(settings.AWX_ISOLATION_BASE_PATH):
raise RuntimeError('AWX_ISOLATION_BASE_PATH=%s does not exist' % settings.AWX_ISOLATION_BASE_PATH)
if flag_enabled("FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED"):
logger.info(f'Generating workload identity tokens for {self.instance.log_format}')
self.populate_workload_identity_tokens()
if self.instance.status == 'error':
raise RuntimeError('not starting %s task' % self.instance.status)
# May have to serialize the value
private_data_files, ssh_key_data = self.build_private_data_files(self.instance, private_data_dir)
passwords = self.build_passwords(self.instance, kwargs)
@@ -707,7 +632,7 @@ class BaseTask(object):
self.runner_callback.job_created = str(self.instance.created)
credentials = self._credentials
credentials = self.build_credentials_list(self.instance)
container_root = None
if settings.IS_K8S and isinstance(self.instance, ProjectUpdate):
@@ -1002,29 +927,6 @@ class RunJob(SourceControlMixin, BaseTask):
model = Job
event_model = JobEvent
def _extract_credentials_of_kind(self, kind: str):
return (cred for cred in self._credentials if cred.credential_type.kind == kind)
@property
def _machine_credential(self) -> object:
"""Get machine credential."""
return next(self._extract_credentials_of_kind('ssh'), None)
@property
def _vault_credentials(self) -> list[object]:
"""Get vault credentials."""
return list(self._extract_credentials_of_kind('vault'))
@property
def _network_credentials(self) -> list[object]:
"""Get network credentials."""
return list(self._extract_credentials_of_kind('net'))
@property
def _cloud_credentials(self) -> list[object]:
"""Get cloud credentials."""
return list(self._extract_credentials_of_kind('cloud'))
def build_private_data(self, job, private_data_dir):
"""
Returns a dict of the form
@@ -1042,7 +944,7 @@ class RunJob(SourceControlMixin, BaseTask):
}
"""
private_data = {'credentials': {}}
for credential in self._credentials:
for credential in job.credentials.prefetch_related('input_sources__source_credential').all():
# If we were sent SSH credentials, decrypt them and send them
# back (they will be written to a temporary file).
if credential.has_input('ssh_key_data'):
@@ -1058,14 +960,14 @@ class RunJob(SourceControlMixin, BaseTask):
and ansible-vault.
"""
passwords = super(RunJob, self).build_passwords(job, runtime_passwords)
cred = self._machine_credential
cred = job.machine_credential
if cred:
for field in ('ssh_key_unlock', 'ssh_password', 'become_password', 'vault_password'):
value = runtime_passwords.get(field, cred.get_input('password' if field == 'ssh_password' else field, default=''))
if value not in ('', 'ASK'):
passwords[field] = value
for cred in self._vault_credentials:
for cred in job.vault_credentials:
field = 'vault_password'
vault_id = cred.get_input('vault_id', default=None)
if vault_id:
@@ -1081,7 +983,7 @@ class RunJob(SourceControlMixin, BaseTask):
key unlock over network key unlock.
'''
if 'ssh_key_unlock' not in passwords:
for cred in self._network_credentials:
for cred in job.network_credentials:
if cred.inputs.get('ssh_key_unlock'):
passwords['ssh_key_unlock'] = runtime_passwords.get('ssh_key_unlock', cred.get_input('ssh_key_unlock', default=''))
break
@@ -1116,11 +1018,11 @@ class RunJob(SourceControlMixin, BaseTask):
# Set environment variables for cloud credentials.
cred_files = private_data_files.get('credentials', {})
for cloud_cred in self._cloud_credentials:
for cloud_cred in job.cloud_credentials:
if cloud_cred and cloud_cred.credential_type.namespace == 'openstack' and cred_files.get(cloud_cred, ''):
env['OS_CLIENT_CONFIG_FILE'] = get_incontainer_path(cred_files.get(cloud_cred, ''), private_data_dir)
for network_cred in self._network_credentials:
for network_cred in job.network_credentials:
env['ANSIBLE_NET_USERNAME'] = network_cred.get_input('username', default='')
env['ANSIBLE_NET_PASSWORD'] = network_cred.get_input('password', default='')
@@ -1138,11 +1040,12 @@ class RunJob(SourceControlMixin, BaseTask):
('ANSIBLE_COLLECTIONS_PATH', 'collections_path', 'requirements_collections', '~/.ansible/collections:/usr/share/ansible/collections'),
]
path_vars.append(
('ANSIBLE_CALLBACK_PLUGINS', 'callback_plugins', 'plugins_path', '~/.ansible/plugins:/plugins/callback:/usr/share/ansible/plugins/callback'),
)
if flag_enabled("FEATURE_INDIRECT_NODE_COUNTING_ENABLED"):
path_vars.append(
('ANSIBLE_CALLBACK_PLUGINS', 'callback_plugins', 'plugins_path', '~/.ansible/plugins:/plugins/callback:/usr/share/ansible/plugins/callback'),
)
config_values = read_ansible_config(os.path.join(private_data_dir, 'project'), list(map(lambda x: x[1], path_vars)) + ['callbacks_enabled'])
config_values = read_ansible_config(os.path.join(private_data_dir, 'project'), list(map(lambda x: x[1], path_vars)))
for env_key, config_setting, folder, default in path_vars:
paths = default.split(':')
@@ -1157,16 +1060,10 @@ class RunJob(SourceControlMixin, BaseTask):
paths = [os.path.join(CONTAINER_ROOT, folder)] + paths
env[env_key] = os.pathsep.join(paths)
env['ANSIBLE_CALLBACKS_ENABLED'] = 'indirect_instance_count'
if 'callbacks_enabled' in config_values:
env['ANSIBLE_CALLBACKS_ENABLED'] += ',' + config_values['callbacks_enabled']
if flag_enabled("FEATURE_INDIRECT_NODE_COUNTING_ENABLED"):
env['AWX_COLLECT_HOST_QUERIES'] = '1'
# Add vendor collections path for external query file discovery
vendor_collections_path = os.path.join(CONTAINER_ROOT, 'vendor_collections')
env['ANSIBLE_COLLECTIONS_PATH'] = f"{vendor_collections_path}:{env['ANSIBLE_COLLECTIONS_PATH']}"
logger.debug(f"ANSIBLE_COLLECTIONS_PATH updated for vendor collections: {env['ANSIBLE_COLLECTIONS_PATH']}")
env['ANSIBLE_CALLBACKS_ENABLED'] = 'indirect_instance_count'
if 'callbacks_enabled' in config_values:
env['ANSIBLE_CALLBACKS_ENABLED'] += ':' + config_values['callbacks_enabled']
return env
@@ -1175,7 +1072,7 @@ class RunJob(SourceControlMixin, BaseTask):
Build command line argument list for running ansible-playbook,
optionally using ssh-agent for public/private key authentication.
"""
creds = self._machine_credential
creds = job.machine_credential
ssh_username, become_username, become_method = '', '', ''
if creds:
@@ -1327,17 +1224,10 @@ class RunJob(SourceControlMixin, BaseTask):
return
if self.should_use_fact_cache() and self.runner_callback.artifacts_processed:
job.log_lifecycle("finish_job_fact_cache")
if job.inventory.kind == 'constructed':
hosts_qs = job.get_source_hosts_for_constructed_inventory()
else:
hosts_qs = job.inventory.hosts
hosts_qs = hosts_qs.only(*HOST_FACTS_FIELDS)
finish_fact_cache(
hosts_qs,
artifacts_dir=os.path.join(private_data_dir, 'artifacts', str(job.id)),
job_id=job.id,
inventory_id=job.inventory_id,
job_created=job.created,
)
def final_run_hook(self, job, status, private_data_dir):
@@ -1612,14 +1502,16 @@ class RunProjectUpdate(BaseTask):
shutil.copytree(cache_subpath, dest_subpath, symlinks=True)
logger.debug('{0} {1} prepared {2} from cache'.format(type(project).__name__, project.pk, dest_subpath))
pdd_plugins_path = os.path.join(job_private_data_dir, 'plugins_path')
if not os.path.exists(pdd_plugins_path):
os.mkdir(pdd_plugins_path)
from awx.playbooks import library
if flag_enabled("FEATURE_INDIRECT_NODE_COUNTING_ENABLED"):
# copy the special callback (not stdout type) plugin to get list of collections
pdd_plugins_path = os.path.join(job_private_data_dir, 'plugins_path')
if not os.path.exists(pdd_plugins_path):
os.mkdir(pdd_plugins_path)
from awx.playbooks import library
plugin_file_source = os.path.join(library.__path__[0], 'indirect_instance_count.py')
plugin_file_dest = os.path.join(pdd_plugins_path, 'indirect_instance_count.py')
shutil.copyfile(plugin_file_source, plugin_file_dest)
plugin_file_source = os.path.join(library.__path__._path[0], 'indirect_instance_count.py')
plugin_file_dest = os.path.join(pdd_plugins_path, 'indirect_instance_count.py')
shutil.copyfile(plugin_file_source, plugin_file_dest)
def post_run_hook(self, instance, status):
super(RunProjectUpdate, self).post_run_hook(instance, status)
@@ -1682,7 +1574,7 @@ class RunProjectUpdate(BaseTask):
return params
def build_credentials_list(self, project_update):
if project_update.credential:
if project_update.scm_type == 'insights' and project_update.credential:
return [project_update.credential]
return []
@@ -1865,24 +1757,6 @@ class RunInventoryUpdate(SourceControlMixin, BaseTask):
# All credentials not used by inventory source injector
return inventory_update.get_extra_credentials()
def populate_workload_identity_tokens(self, additional_credentials=None):
"""Also generate OIDC tokens for the cloud credential.
The cloud credential is not in _credentials (it is handled by the
inventory source injector), but it may still need a workload identity
token generated for it.
"""
cloud_cred = self.instance.get_cloud_credential()
creds = list(additional_credentials or [])
if cloud_cred:
creds.append(cloud_cred)
super().populate_workload_identity_tokens(additional_credentials=creds or None)
# Override get_cloud_credential on this instance so the injector
# uses the credential with OIDC context instead of doing a fresh
# DB fetch that would lose it.
if cloud_cred and cloud_cred.context:
self.instance.get_cloud_credential = lambda: cloud_cred
def build_project_dir(self, inventory_update, private_data_dir):
source_project = None
if inventory_update.inventory_source:

View File

@@ -393,9 +393,9 @@ def evaluate_policy(instance):
raise PolicyEvaluationError(_('Following certificate settings are missing for OPA_AUTH_TYPE=Certificate: {}').format(cert_settings_missing))
query_paths = [
('Organization', instance.organization.opa_query_path if instance.organization else None),
('Inventory', instance.inventory.opa_query_path if instance.inventory else None),
('Job template', instance.job_template.opa_query_path if instance.job_template else None),
('Organization', instance.organization.opa_query_path),
('Inventory', instance.inventory.opa_query_path),
('Job template', instance.job_template.opa_query_path),
]
violations = dict()
errors = dict()

View File

@@ -1,19 +0,0 @@
---
authors:
- AWX Project Contributors <awx-project@googlegroups.com>
dependencies: {}
description: External query testing collection. No embedded query file. Not for use in production.
documentation: https://github.com/ansible/awx
homepage: https://github.com/ansible/awx
issues: https://github.com/ansible/awx
license:
- GPL-3.0-or-later
name: external
namespace: demo
readme: README.md
repository: https://github.com/ansible/awx
tags:
- demo
- testing
- external_query
version: 1.0.0

View File

@@ -1,78 +0,0 @@
#!/usr/bin/python
# Same licensing as AWX
from __future__ import absolute_import, division, print_function
__metaclass__ = type
DOCUMENTATION = r'''
---
module: example
short_description: Module for specific live tests
version_added: "2.0.0"
description: This module is part of a test collection in local source. Used for external query testing.
options:
host_name:
description: Name to return as the host name.
required: false
type: str
author:
- AWX Live Tests
'''
EXAMPLES = r'''
- name: Test with defaults
demo.external.example:
- name: Test with custom host name
demo.external.example:
host_name: foo_host
'''
RETURN = r'''
direct_host_name:
description: The name of the host, this will be collected with the feature.
type: str
returned: always
sample: 'foo_host'
'''
from ansible.module_utils.basic import AnsibleModule
def run_module():
module_args = dict(
host_name=dict(type='str', required=False, default='foo_host_default'),
)
result = dict(
changed=False,
other_data='sample_string',
)
module = AnsibleModule(argument_spec=module_args, supports_check_mode=True)
if module.check_mode:
module.exit_json(**result)
result['direct_host_name'] = module.params['host_name']
result['nested_host_name'] = {'host_name': module.params['host_name']}
result['name'] = 'vm-foo'
# non-cononical facts
result['device_type'] = 'Fake Host'
module.exit_json(**result)
def main():
run_module()
if __name__ == '__main__':
main()

View File

@@ -1,19 +0,0 @@
---
authors:
- AWX Project Contributors <awx-project@googlegroups.com>
dependencies: {}
description: External query testing collection v1.5.0. No embedded query file. Not for use in production.
documentation: https://github.com/ansible/awx
homepage: https://github.com/ansible/awx
issues: https://github.com/ansible/awx
license:
- GPL-3.0-or-later
name: external
namespace: demo
readme: README.md
repository: https://github.com/ansible/awx
tags:
- demo
- testing
- external_query
version: 1.5.0

View File

@@ -1,78 +0,0 @@
#!/usr/bin/python
# Same licensing as AWX
from __future__ import absolute_import, division, print_function
__metaclass__ = type
DOCUMENTATION = r'''
---
module: example
short_description: Module for specific live tests
version_added: "2.0.0"
description: This module is part of a test collection in local source. Used for external query testing.
options:
host_name:
description: Name to return as the host name.
required: false
type: str
author:
- AWX Live Tests
'''
EXAMPLES = r'''
- name: Test with defaults
demo.external.example:
- name: Test with custom host name
demo.external.example:
host_name: foo_host
'''
RETURN = r'''
direct_host_name:
description: The name of the host, this will be collected with the feature.
type: str
returned: always
sample: 'foo_host'
'''
from ansible.module_utils.basic import AnsibleModule
def run_module():
module_args = dict(
host_name=dict(type='str', required=False, default='foo_host_default'),
)
result = dict(
changed=False,
other_data='sample_string',
)
module = AnsibleModule(argument_spec=module_args, supports_check_mode=True)
if module.check_mode:
module.exit_json(**result)
result['direct_host_name'] = module.params['host_name']
result['nested_host_name'] = {'host_name': module.params['host_name']}
result['name'] = 'vm-foo'
# non-cononical facts
result['device_type'] = 'Fake Host'
module.exit_json(**result)
def main():
run_module()
if __name__ == '__main__':
main()

View File

@@ -1,19 +0,0 @@
---
authors:
- AWX Project Contributors <awx-project@googlegroups.com>
dependencies: {}
description: External query testing collection v3.0.0. No embedded query file. Not for use in production.
documentation: https://github.com/ansible/awx
homepage: https://github.com/ansible/awx
issues: https://github.com/ansible/awx
license:
- GPL-3.0-or-later
name: external
namespace: demo
readme: README.md
repository: https://github.com/ansible/awx
tags:
- demo
- testing
- external_query
version: 3.0.0

View File

@@ -1,78 +0,0 @@
#!/usr/bin/python
# Same licensing as AWX
from __future__ import absolute_import, division, print_function
__metaclass__ = type
DOCUMENTATION = r'''
---
module: example
short_description: Module for specific live tests
version_added: "2.0.0"
description: This module is part of a test collection in local source. Used for external query testing.
options:
host_name:
description: Name to return as the host name.
required: false
type: str
author:
- AWX Live Tests
'''
EXAMPLES = r'''
- name: Test with defaults
demo.external.example:
- name: Test with custom host name
demo.external.example:
host_name: foo_host
'''
RETURN = r'''
direct_host_name:
description: The name of the host, this will be collected with the feature.
type: str
returned: always
sample: 'foo_host'
'''
from ansible.module_utils.basic import AnsibleModule
def run_module():
module_args = dict(
host_name=dict(type='str', required=False, default='foo_host_default'),
)
result = dict(
changed=False,
other_data='sample_string',
)
module = AnsibleModule(argument_spec=module_args, supports_check_mode=True)
if module.check_mode:
module.exit_json(**result)
result['direct_host_name'] = module.params['host_name']
result['nested_host_name'] = {'host_name': module.params['host_name']}
result['name'] = 'vm-foo'
# non-cononical facts
result['device_type'] = 'Fake Host'
module.exit_json(**result)
def main():
run_module()
if __name__ == '__main__':
main()

View File

@@ -1,11 +0,0 @@
---
- hosts: all
gather_facts: false
connection: local
tasks:
- name: Set artifacts via set_stats
ansible.builtin.set_stats:
data: "{{ stats_data }}"
per_host: false
aggregate: false
when: stats_data is defined

View File

@@ -1,21 +0,0 @@
---
# Generated by Claude Opus 4.6 (claude-opus-4-6).
- hosts: all
vars:
extra_value: ""
gather_facts: false
connection: local
tasks:
- name: set a custom fact
set_fact:
foo: "bar{{ extra_value }}"
bar:
a:
b:
- "c"
- "d"
cacheable: true
- name: sleep to create overlap window for concurrent job testing
wait_for:
timeout: 2

View File

@@ -1,5 +0,0 @@
---
collections:
- name: 'file:///tmp/live_tests/host_query_external_v1_0_0'
type: git
version: devel

View File

@@ -1,8 +0,0 @@
---
- hosts: all
gather_facts: false
connection: local
tasks:
- demo.external.example:
register: result
- debug: var=result

View File

@@ -1,5 +0,0 @@
---
collections:
- name: 'file:///tmp/live_tests/host_query_external_v1_5_0'
type: git
version: devel

View File

@@ -1,8 +0,0 @@
---
- hosts: all
gather_facts: false
connection: local
tasks:
- demo.external.example:
register: result
- debug: var=result

View File

@@ -1,5 +0,0 @@
---
collections:
- name: 'file:///tmp/live_tests/host_query_external_v3_0_0'
type: git
version: devel

View File

@@ -1,8 +0,0 @@
---
- hosts: all
gather_facts: false
connection: local
tasks:
- demo.external.example:
register: result
- debug: var=result

View File

@@ -74,9 +74,9 @@ def temp_analytic_tar():
@pytest.fixture
def mock_analytic_post():
# Patch get_or_generate_candlepin_certificate to skip mTLS path
with mock.patch('awx.main.analytics.core.get_or_generate_candlepin_certificate', return_value=(None, None)):
yield
# Patch the Session.post method to return a mock response with status_code 200
with mock.patch('awx.main.analytics.core.requests.Session.post', return_value=mock.Mock(status_code=200)) as mock_post:
yield mock_post
@pytest.mark.parametrize(
@@ -141,22 +141,15 @@ def mock_analytic_post():
)
@pytest.mark.django_db
def test_ship_credential(setting_map, expected_result, expected_auth, temp_analytic_tar, mock_analytic_post):
with override_settings(**setting_map, AUTOMATION_ANALYTICS_URL='https://example.com/api'):
with mock.patch('awx.main.analytics.core.OIDCClient') as mock_oidc:
mock_oidc_instance = mock.Mock()
mock_oidc_instance.make_request.return_value = mock.Mock(status_code=200)
mock_oidc.return_value = mock_oidc_instance
with override_settings(**setting_map):
result = ship(temp_analytic_tar)
result = ship(temp_analytic_tar)
assert result == expected_result
if expected_auth:
# Verify OIDC client was instantiated with correct credentials
mock_oidc.assert_called_once_with(expected_auth[0], expected_auth[1])
mock_oidc_instance.make_request.assert_called_once()
else:
# When credentials are missing, OIDCClient should not be called
mock_oidc.assert_not_called()
assert result == expected_result
if expected_auth:
mock_analytic_post.assert_called_once()
assert mock_analytic_post.call_args[1]['auth'] == expected_auth
else:
mock_analytic_post.assert_not_called()
@pytest.mark.django_db

View File

@@ -1,4 +1,3 @@
import os
import pytest
import requests
from unittest import mock
@@ -258,92 +257,3 @@ class TestAnalyticsGenericView:
else:
# assert mock_base_auth_request not called
mock_base_auth_request.assert_not_called()
@pytest.mark.django_db
def test__send_to_analytics_respects_proxy_env_oidc(self):
settings_map = {
'INSIGHTS_TRACKING_STATE': True,
'AUTOMATION_ANALYTICS_URL': 'https://example.com',
'REDHAT_USERNAME': 'redhat_user',
'REDHAT_PASSWORD': 'redhat_pass',
'SUBSCRIPTIONS_CLIENT_ID': '',
'SUBSCRIPTIONS_CLIENT_SECRET': '',
'AWX_TASK_ENV': {'HTTPS_PROXY': '192.168.50.100:1234', 'HTTP_PROXY': '192.168.50.100:5678'},
}
with override_settings(**settings_map):
request = RequestFactory().post('/some/path')
view = AnalyticsGenericView()
with mock.patch('awx.api.views.analytics.OIDCClient') as mock_oidc_client:
mock_client_instance = mock.Mock()
mock_oidc_client.return_value = mock_client_instance
def _check_env_and_respond(*args, **kwargs):
assert os.environ.get('HTTPS_PROXY') == '192.168.50.100:1234'
assert os.environ.get('HTTP_PROXY') == '192.168.50.100:5678'
return mock.Mock(status_code=200)
mock_client_instance.make_request.side_effect = _check_env_and_respond
response = view._send_to_analytics(request, 'POST')
assert response.status_code == 200
mock_client_instance.make_request.assert_called_once()
@pytest.mark.django_db
def test__send_to_analytics_respects_proxy_env_basic_auth(self):
settings_map = {
'INSIGHTS_TRACKING_STATE': True,
'AUTOMATION_ANALYTICS_URL': 'https://example.com',
'REDHAT_USERNAME': 'redhat_user',
'REDHAT_PASSWORD': 'redhat_pass',
'SUBSCRIPTIONS_CLIENT_ID': '',
'SUBSCRIPTIONS_CLIENT_SECRET': '',
'AWX_TASK_ENV': {'HTTPS_PROXY': '192.168.50.100:1234'},
}
with override_settings(**settings_map):
request = RequestFactory().post('/some/path')
view = AnalyticsGenericView()
with mock.patch('awx.api.views.analytics.OIDCClient') as mock_oidc_client, mock.patch(
'awx.api.views.analytics.AnalyticsGenericView._base_auth_request'
) as mock_base_auth:
mock_client_instance = mock.Mock()
mock_oidc_client.return_value = mock_client_instance
mock_client_instance.make_request.side_effect = requests.RequestException("OIDC failed")
def _check_env_and_respond(*args, **kwargs):
assert os.environ.get('HTTPS_PROXY') == '192.168.50.100:1234'
return mock.Mock(status_code=200)
mock_base_auth.side_effect = _check_env_and_respond
response = view._send_to_analytics(request, 'POST')
assert response.status_code == 200
mock_base_auth.assert_called_once()
@pytest.mark.django_db
def test__send_to_analytics_restores_env_after_request(self):
original_value = os.environ.pop('HTTPS_PROXY', None)
settings_map = {
'INSIGHTS_TRACKING_STATE': True,
'AUTOMATION_ANALYTICS_URL': 'https://example.com',
'REDHAT_USERNAME': 'redhat_user',
'REDHAT_PASSWORD': 'redhat_pass',
'SUBSCRIPTIONS_CLIENT_ID': '',
'SUBSCRIPTIONS_CLIENT_SECRET': '',
'AWX_TASK_ENV': {'HTTPS_PROXY': '192.168.50.100:1234'},
}
try:
with override_settings(**settings_map):
request = RequestFactory().post('/some/path')
view = AnalyticsGenericView()
with mock.patch('awx.api.views.analytics.OIDCClient') as mock_oidc_client:
mock_client_instance = mock.Mock()
mock_oidc_client.return_value = mock_client_instance
mock_client_instance.make_request.return_value = mock.Mock(status_code=200)
view._send_to_analytics(request, 'POST')
assert 'HTTPS_PROXY' not in os.environ
finally:
if original_value is not None:
os.environ['HTTPS_PROXY'] = original_value

View File

@@ -1,84 +0,0 @@
import pytest
from awx.api.versioning import reverse
from rest_framework import status
from awx.main.models.jobs import JobTemplate
@pytest.mark.django_db
class TestConfigEndpointFields:
def test_base_fields_all_users(self, get, rando):
url = reverse('api:api_v2_config_view')
response = get(url, rando, expect=200)
assert 'time_zone' in response.data
assert 'license_info' in response.data
assert 'version' in response.data
assert 'eula' in response.data
assert 'analytics_status' in response.data
assert 'analytics_collectors' in response.data
assert 'become_methods' in response.data
@pytest.mark.parametrize(
"role_type",
[
"superuser",
"system_auditor",
"org_admin",
"org_auditor",
"org_project_admin",
],
)
def test_privileged_users_conditional_fields(self, get, user, organization, admin, role_type):
url = reverse('api:api_v2_config_view')
if role_type == "superuser":
test_user = admin
elif role_type == "system_auditor":
test_user = user('system-auditor', is_superuser=False)
test_user.is_system_auditor = True
test_user.save()
elif role_type == "org_admin":
test_user = user('org-admin', is_superuser=False)
organization.admin_role.members.add(test_user)
elif role_type == "org_auditor":
test_user = user('org-auditor', is_superuser=False)
organization.auditor_role.members.add(test_user)
elif role_type == "org_project_admin":
test_user = user('org-project-admin', is_superuser=False)
organization.project_admin_role.members.add(test_user)
response = get(url, test_user, expect=200)
assert 'project_base_dir' in response.data
assert 'project_local_paths' in response.data
assert 'custom_virtualenvs' in response.data
def test_job_template_admin_gets_venvs_only(self, get, user, organization, project, inventory):
"""Test that JobTemplate admin without org access gets only custom_virtualenvs"""
jt_admin = user('jt-admin', is_superuser=False)
jt = JobTemplate.objects.create(name='test-jt', organization=organization, project=project, inventory=inventory)
jt.admin_role.members.add(jt_admin)
url = reverse('api:api_v2_config_view')
response = get(url, jt_admin, expect=200)
assert 'custom_virtualenvs' in response.data
assert 'project_base_dir' not in response.data
assert 'project_local_paths' not in response.data
def test_normal_user_no_conditional_fields(self, get, rando):
url = reverse('api:api_v2_config_view')
response = get(url, rando, expect=200)
assert 'project_base_dir' not in response.data
assert 'project_local_paths' not in response.data
assert 'custom_virtualenvs' not in response.data
def test_unauthenticated_denied(self, get):
"""Test that unauthenticated requests are denied"""
url = reverse('api:api_v2_config_view')
response = get(url, None, expect=401)
assert response.status_code == status.HTTP_401_UNAUTHORIZED

View File

@@ -200,7 +200,6 @@ def test_grant_org_credential_to_org_user_through_user_roles(post, credential, o
@pytest.mark.django_db
def test_grant_org_credential_to_non_org_user_through_role_users(post, credential, organization, org_admin, alice):
# NOTE: this endpoint is going away soon
credential.organization = organization
credential.save()
response = post(reverse('api:role_users_list', kwargs={'pk': credential.use_role.id}), {'id': alice.id}, org_admin)
@@ -209,7 +208,6 @@ def test_grant_org_credential_to_non_org_user_through_role_users(post, credentia
@pytest.mark.django_db
def test_grant_org_credential_to_non_org_user_through_user_roles(post, credential, organization, org_admin, alice):
# NOTE: this endpoint is going away soon
credential.organization = organization
credential.save()
response = post(reverse('api:user_roles_list', kwargs={'pk': alice.id}), {'id': credential.use_role.id}, org_admin)
@@ -218,18 +216,18 @@ def test_grant_org_credential_to_non_org_user_through_user_roles(post, credentia
@pytest.mark.django_db
def test_grant_private_credential_to_user_through_role_users(post, credential, alice, bob):
# NOTE: this endpoint is going away soon
# normal users can't do this
credential.admin_role.members.add(alice)
response = post(reverse('api:role_users_list', kwargs={'pk': credential.use_role.id}), {'id': bob.id}, alice)
assert response.status_code == 403
assert response.status_code == 400
@pytest.mark.django_db
def test_grant_private_credential_to_org_user_through_role_users(post, credential, org_admin, org_member):
# NOTE: this endpoint is going away soon
# org admins can't either
credential.admin_role.members.add(org_admin)
response = post(reverse('api:role_users_list', kwargs={'pk': credential.use_role.id}), {'id': org_member.id}, org_admin)
assert response.status_code == 204
assert response.status_code == 400
@pytest.mark.django_db
@@ -241,18 +239,18 @@ def test_sa_grant_private_credential_to_user_through_role_users(post, credential
@pytest.mark.django_db
def test_grant_private_credential_to_user_through_user_roles(post, credential, alice, bob):
# NOTE: this endpoint is going away soon
# normal users can't do this
credential.admin_role.members.add(alice)
response = post(reverse('api:user_roles_list', kwargs={'pk': bob.id}), {'id': credential.use_role.id}, alice)
assert response.status_code == 403
assert response.status_code == 400
@pytest.mark.django_db
def test_grant_private_credential_to_org_user_through_user_roles(post, credential, org_admin, org_member):
# NOTE: this endpoint is going away soon
# org admins can't either
credential.admin_role.members.add(org_admin)
response = post(reverse('api:user_roles_list', kwargs={'pk': org_member.id}), {'id': credential.use_role.id}, org_admin)
assert response.status_code == 204
assert response.status_code == 400
@pytest.mark.django_db
@@ -284,14 +282,14 @@ def test_grant_org_credential_to_team_through_team_roles(post, credential, organ
@pytest.mark.django_db
def test_sa_grant_private_credential_to_team_through_role_teams(post, credential, admin, team):
# NOTE: this endpoint is going away soon
# not even a system admin can grant a private cred to a team though
response = post(reverse('api:role_teams_list', kwargs={'pk': credential.use_role.id}), {'id': team.id}, admin)
assert response.status_code == 204
assert response.status_code == 400
@pytest.mark.django_db
def test_grant_credential_to_team_different_organization_through_role_teams(post, get, credential, organizations, admin, org_admin, team, team_member):
# NOTE: this endpoint is going away soon
# # Test that credential from different org can be assigned to team by a superuser through role_teams_list endpoint
orgs = organizations(2)
credential.organization = orgs[0]
credential.save()
@@ -301,7 +299,10 @@ def test_grant_credential_to_team_different_organization_through_role_teams(post
# Non-superuser (org_admin) trying cross-org assignment should be denied
response = post(reverse('api:role_teams_list', kwargs={'pk': credential.use_role.id}), {'id': team.id}, org_admin)
assert response.status_code == 400
assert "You cannot grant credential access to a Team not in the credentials' organization" in str(response.data['detail'])
assert (
"You cannot grant a team access to a credential in a different organization. Only superusers can grant cross-organization credential access to teams"
in response.data['msg']
)
# Superuser (admin) can do cross-org assignment
response = post(reverse('api:role_teams_list', kwargs={'pk': credential.use_role.id}), {'id': team.id}, admin)
@@ -315,17 +316,20 @@ def test_grant_credential_to_team_different_organization_through_role_teams(post
@pytest.mark.django_db
def test_grant_credential_to_team_different_organization(post, get, credential, organizations, admin, org_admin, team, team_member):
# NOTE: this endpoint is going away soon
# Test that credential from different org can be assigned to team by a superuser
orgs = organizations(2)
credential.organization = orgs[0]
credential.save()
team.organization = orgs[1]
team.save()
# Non-superuser (org_admin) trying cross-org assignment should be denied
# Non-superuser (org_admin, ...) trying cross-org assignment should be denied
response = post(reverse('api:team_roles_list', kwargs={'pk': team.id}), {'id': credential.use_role.id}, org_admin)
assert response.status_code == 400
assert "You cannot grant credential access to a Team not in the credentials' organization" in str(response.data['detail'])
assert (
"You cannot grant a team access to a credential in a different organization. Only superusers can grant cross-organization credential access to teams"
in response.data['msg']
)
# Superuser (system admin) can do cross-org assignment
response = post(reverse('api:team_roles_list', kwargs={'pk': team.id}), {'id': credential.use_role.id}, admin)

View File

@@ -1,7 +1,5 @@
import pytest
from ansible_base.lib.testing.util import feature_flag_enabled, feature_flag_disabled
from awx.main.models import CredentialInputSource
from awx.api.versioning import reverse
@@ -318,60 +316,3 @@ def test_create_credential_input_source_with_already_used_input_returns_400(post
]
all_responses = [post(list_url, params, admin) for params in all_params]
assert all_responses.pop().status_code == 400
@pytest.mark.django_db
def test_credential_input_source_passes_workload_identity_token_when_flag_enabled(vault_credential, external_credential, mocker):
"""Test that workload_identity_token is passed to backend when flag is enabled."""
with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
# Add workload_identity_token as an internal field on the external credential type
# so get_input_value resolves it from the per-input-source context
external_credential.credential_type.inputs['fields'].append(
{'id': 'workload_identity_token', 'label': 'Workload Identity Token', 'type': 'string', 'internal': True}
)
# Create an input source
input_source = CredentialInputSource.objects.create(
target_credential=vault_credential,
source_credential=external_credential,
input_field_name='vault_password',
metadata={'key': 'test_key'},
)
# Mock the credential plugin backend
mock_backend = mocker.patch.object(external_credential.credential_type.plugin, 'backend', autospec=True, return_value='test_value')
# Call with context keyed by input source PK
test_context = {input_source.pk: {'workload_identity_token': 'jwt_token_here'}}
result = input_source.get_input_value(context=test_context)
# Verify backend was called with workload_identity_token
assert result == 'test_value'
call_kwargs = mock_backend.call_args[1]
assert call_kwargs['workload_identity_token'] == 'jwt_token_here'
assert call_kwargs['key'] == 'test_key'
@pytest.mark.django_db
def test_credential_input_source_skips_workload_identity_token_when_flag_disabled(vault_credential, external_credential, mocker):
"""Test that workload_identity_token is NOT passed when flag is disabled."""
with feature_flag_disabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
# Create an input source
input_source = CredentialInputSource.objects.create(
target_credential=vault_credential,
source_credential=external_credential,
input_field_name='vault_password',
metadata={'key': 'test_key'},
)
# Mock the credential plugin backend
mock_backend = mocker.patch.object(external_credential.credential_type.plugin, 'backend', autospec=True, return_value='test_value')
# Call with context containing workload_identity_token but NO internal field defined,
# simulating a flag-disabled scenario where tokens are not generated upstream
test_context = {input_source.pk: {'workload_identity_token': 'jwt_token_here'}}
result = input_source.get_input_value(context=test_context)
# Verify backend was called WITHOUT workload_identity_token since the credential type
# does not define it as an internal field (flag-disabled path doesn't register it)
assert result == 'test_value'
call_kwargs = mock_backend.call_args[1]
assert 'workload_identity_token' not in call_kwargs
assert call_kwargs['key'] == 'test_key'

View File

@@ -2,7 +2,6 @@ import json
import pytest
from ansible_base.lib.testing.util import feature_flag_enabled
from awx.main.models.credential import CredentialType, Credential
from awx.api.versioning import reverse
@@ -160,8 +159,7 @@ def test_create_as_admin(get, post, admin):
response = get(reverse('api:credential_type_list'), admin)
assert response.data['count'] == 1
assert response.data['results'][0]['name'] == 'Custom Credential Type'
# Serializer normalizes empty inputs to {'fields': []}
assert response.data['results'][0]['inputs'] == {'fields': []}
assert response.data['results'][0]['inputs'] == {}
assert response.data['results'][0]['injectors'] == {}
assert response.data['results'][0]['managed'] is False
@@ -476,98 +474,3 @@ def test_credential_type_rbac_external_test(post, alice, admin, credentialtype_e
data = {'inputs': {}, 'metadata': {}}
assert post(url, data, admin).status_code == 202
assert post(url, data, alice).status_code == 403
# --- Tests for internal field filtering with None/invalid inputs ---
@pytest.mark.django_db
def test_credential_type_with_none_inputs(get, admin):
"""Test that credential type with empty inputs dict works correctly."""
# Create a credential type with empty dict
ct = CredentialType.objects.create(
kind='cloud',
name='Test Type',
managed=False,
inputs={}, # Empty dict, not None (DB has NOT NULL constraint)
)
url = reverse('api:credential_type_detail', kwargs={'pk': ct.pk})
response = get(url, admin)
assert response.status_code == 200
# Should have normalized inputs to empty dict
assert 'inputs' in response.data
assert isinstance(response.data['inputs'], dict)
assert response.data['inputs']['fields'] == []
@pytest.mark.django_db
def test_credential_type_with_invalid_inputs_type(get, admin):
"""Test that credential type with non-dict inputs doesn't cause errors."""
# Create a credential type with invalid inputs type
ct = CredentialType.objects.create(kind='cloud', name='Test Type', managed=False, inputs={'fields': 'not-a-list'})
url = reverse('api:credential_type_detail', kwargs={'pk': ct.pk})
response = get(url, admin)
assert response.status_code == 200
# Should gracefully handle invalid fields type
assert 'inputs' in response.data
assert response.data['inputs']['fields'] == []
@pytest.mark.django_db
def test_credential_type_filters_internal_fields(get, admin):
"""Test that internal fields are filtered from API responses."""
ct = CredentialType.objects.create(
kind='cloud',
name='Test OIDC Type',
managed=False,
inputs={
'fields': [
{'id': 'url', 'label': 'URL', 'type': 'string'},
{'id': 'token', 'label': 'Token', 'type': 'string', 'secret': True, 'internal': True},
{'id': 'public_field', 'label': 'Public', 'type': 'string'},
]
},
)
url = reverse('api:credential_type_detail', kwargs={'pk': ct.pk})
with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
response = get(url, admin)
assert response.status_code == 200
field_ids = [f['id'] for f in response.data['inputs']['fields']]
# Internal field should be filtered out
assert 'token' not in field_ids
assert 'url' in field_ids
assert 'public_field' in field_ids
@pytest.mark.django_db
def test_credential_type_list_filters_internal_fields(get, admin):
"""Test that internal fields are filtered in list view."""
CredentialType.objects.create(
kind='cloud',
name='Test OIDC Type',
managed=False,
inputs={
'fields': [
{'id': 'url', 'label': 'URL', 'type': 'string'},
{'id': 'workload_identity_token', 'label': 'Token', 'type': 'string', 'secret': True, 'internal': True},
]
},
)
url = reverse('api:credential_type_list')
with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
response = get(url, admin)
assert response.status_code == 200
# Find our credential type in the results
test_ct = next((ct for ct in response.data['results'] if ct['name'] == 'Test OIDC Type'), None)
assert test_ct is not None
field_ids = [f['id'] for f in test_ct['inputs']['fields']]
# Internal field should be filtered out
assert 'workload_identity_token' not in field_ids
assert 'url' in field_ids

View File

@@ -1,34 +0,0 @@
import pytest
from awx.api.versioning import reverse
from awx.main.models import Host, Inventory
@pytest.mark.django_db
def test_dashboard_hosts_total_excludes_constructed(get, admin_user, organization):
"""
Constructed inventory hosts are not counted in the dashboard
"""
source_inv = Inventory.objects.create(name='source-inv', organization=organization)
source_host = source_inv.hosts.create(name='host1')
constructed = Inventory.objects.create(name='constructed-inv', kind='constructed', organization=organization)
Host.objects.create(name='host1', inventory=constructed, instance_id=str(source_host.pk))
response = get(reverse('api:dashboard_view'), user=admin_user, expect=200)
assert response.data['hosts']['total'] == 1
@pytest.mark.django_db
def test_host_list_still_returns_constructed(get, admin_user, organization):
"""
Constructed inventory hosts are still visible through the API
"""
source_inv = Inventory.objects.create(name='source-inv', organization=organization)
source_host = source_inv.hosts.create(name='host1')
constructed = Inventory.objects.create(name='constructed-inv', kind='constructed', organization=organization)
Host.objects.create(name='host1', inventory=constructed, instance_id=str(source_host.pk))
response = get(reverse('api:host_list'), user=admin_user, expect=200)
assert response.data['count'] == 2

View File

@@ -7,7 +7,7 @@ from django.core.exceptions import ValidationError
from awx.api.versioning import reverse
from awx.main.models import InventorySource, Inventory, ActivityStream, Organization
from awx.main.models import InventorySource, Inventory, ActivityStream
from awx.main.utils.inventory_vars import update_group_variables
@@ -963,45 +963,3 @@ class TestInventoryAllVariables:
# Test step 6: Value of var x from source A reappears, because the
# latest update from source B did not contain var x.
self.update_and_verify(inv_src_c, {}, expect={"x": 1}, teststep=6)
@pytest.mark.django_db
def test_inventory_names_unique_per_organization(post, admin_user):
"""Validate that two inventories can have the same name if they belong to different organizations."""
org1 = Organization.objects.create(name='org-inv-1')
org2 = Organization.objects.create(name='org-inv-2')
inv_name = 'SharedInventoryName'
# Create inventory with same name in org1
resp1 = post(
reverse('api:inventory_list'),
{'name': inv_name, 'organization': org1.id},
admin_user,
expect=201,
)
inv1_id = resp1.data['id']
# Create inventory with same name in org2 - should succeed
resp2 = post(
reverse('api:inventory_list'),
{'name': inv_name, 'organization': org2.id},
admin_user,
expect=201,
)
inv2_id = resp2.data['id']
assert inv1_id != inv2_id
inv1 = Inventory.objects.get(id=inv1_id)
inv2 = Inventory.objects.get(id=inv2_id)
assert inv1.name == inv2.name == inv_name
assert inv1.organization.id == org1.id
assert inv2.organization.id == org2.id
# Attempt to create another inventory with same name in org1 - should fail
resp3 = post(
reverse('api:inventory_list'),
{'name': inv_name, 'organization': org1.id},
admin_user,
expect=400,
)
assert 'Inventory with this Name and Organization already exists' in json.dumps(resp3.data)

View File

@@ -485,3 +485,47 @@ class TestJobTemplateCallbackProxyIntegration:
expect=400,
**headers
)
@override_settings(REMOTE_HOST_HEADERS=['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST'], PROXY_IP_ALLOWED_LIST=[])
def test_only_first_entry_in_comma_separated_header_is_considered(self, job_template, admin_user, post):
"""
Test that only the first entry in a comma-separated header value is used for host matching.
This is important for X-Forwarded-For style headers where the format is "client, proxy1, proxy2".
Only the original client (first entry) should be matched against inventory hosts.
"""
# Create host that matches the SECOND entry in the comma-separated list
job_template.inventory.hosts.create(name='second-host.example.com')
headers = {
# First entry is 'first-host.example.com', second is 'second-host.example.com'
# Only the first should be considered, so this should NOT match
'HTTP_X_FROM_THE_LOAD_BALANCER': 'first-host.example.com, second-host.example.com',
'REMOTE_ADDR': 'unrelated-addr',
'REMOTE_HOST': 'unrelated-host',
}
# Should return 400 because only 'first-host.example.com' is considered,
# and that host is NOT in the inventory
r = post(
url=reverse('api:job_template_callback', kwargs={'pk': job_template.pk}), data={'host_config_key': 'abcd'}, user=admin_user, expect=400, **headers
)
assert r.data['msg'] == 'No matching host could be found!'
@override_settings(REMOTE_HOST_HEADERS=['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST'], PROXY_IP_ALLOWED_LIST=[])
def test_first_entry_in_comma_separated_header_matches(self, job_template, admin_user, post):
"""
Test that the first entry in a comma-separated header value correctly matches an inventory host.
"""
# Create host that matches the FIRST entry in the comma-separated list
job_template.inventory.hosts.create(name='first-host.example.com')
headers = {
# First entry is 'first-host.example.com', second is 'second-host.example.com'
# The first entry matches the inventory host
'HTTP_X_FROM_THE_LOAD_BALANCER': 'first-host.example.com, second-host.example.com',
'REMOTE_ADDR': 'unrelated-addr',
'REMOTE_HOST': 'unrelated-host',
}
# Should return 201 because 'first-host.example.com' is the first entry and matches
post(url=reverse('api:job_template_callback', kwargs={'pk': job_template.pk}), data={'host_config_key': 'abcd'}, user=admin_user, expect=201, **headers)

View File

@@ -1,92 +0,0 @@
# -*- coding: utf-8 -*-
import pytest
from awx.api.versioning import reverse
from awx.main.models import NotificationTemplate, Organization
@pytest.mark.django_db
def test_notification_template_names_unique_per_organization(post, admin_user):
"""
Validate that notification templates must have unique names within an organization,
but can have the same name across different organizations.
"""
org1 = Organization.objects.create(name='org-notif-1')
org2 = Organization.objects.create(name='org-notif-2')
template_name = 'SharedNotificationName'
# Create notification template in org1
resp1 = post(
reverse('api:notification_template_list'),
{
'name': template_name,
'organization': org1.id,
'notification_type': 'email',
'notification_configuration': {
'username': 'user@example.com',
'password': 'pass',
'sender': 'sender@example.com',
'recipients': ['recipient@example.com'],
'host': 'smtp.example.com',
'port': 25,
'use_tls': False,
'use_ssl': False,
},
},
admin_user,
expect=201,
)
template1_id = resp1.data['id']
# Create notification template with same name in org2 - should succeed
resp2 = post(
reverse('api:notification_template_list'),
{
'name': template_name,
'organization': org2.id,
'notification_type': 'email',
'notification_configuration': {
'username': 'user@example.com',
'password': 'pass',
'sender': 'sender@example.com',
'recipients': ['recipient@example.com'],
'host': 'smtp.example.com',
'port': 25,
'use_tls': False,
'use_ssl': False,
},
},
admin_user,
expect=201,
)
template2_id = resp2.data['id']
assert template1_id != template2_id
template1 = NotificationTemplate.objects.get(id=template1_id)
template2 = NotificationTemplate.objects.get(id=template2_id)
assert template1.name == template2.name == template_name
assert template1.organization.id == org1.id
assert template2.organization.id == org2.id
# Attempt to create another notification template with same name in org1 - should fail
resp3 = post(
reverse('api:notification_template_list'),
{
'name': template_name,
'organization': org1.id,
'notification_type': 'email',
'notification_configuration': {
'username': 'user@example.com',
'password': 'pass',
'sender': 'sender@example.com',
'recipients': ['recipient@example.com'],
'host': 'smtp.example.com',
'port': 25,
'use_tls': False,
'use_ssl': False,
},
},
admin_user,
expect=400,
)
assert 'Notification template with this Organization and Name already exists' in str(resp3.data)

View File

@@ -1,163 +0,0 @@
"""
Tests for OIDC workload identity credential type feature flag.
The FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED flag is an install-time flag that
controls whether OIDC credential types are loaded into the registry at startup.
When disabled, OIDC credential types are not loaded and do not exist in the database.
"""
import pytest
from unittest import mock
from django.test import override_settings
from awx.main.constants import OIDC_CREDENTIAL_TYPE_NAMESPACES
from awx.main.models.credential import CredentialType, ManagedCredentialType, load_credentials
from awx.api.versioning import reverse
@pytest.fixture
def reload_credentials_with_flag(django_db_setup, django_db_blocker):
"""
Fixture that reloads credentials with a specific flag state.
This simulates what happens at application startup.
"""
# Save original registry state
original_registry = ManagedCredentialType.registry.copy()
def _reload(flag_enabled):
with django_db_blocker.unblock():
# Clear the entire registry before reloading
ManagedCredentialType.registry.clear()
# Reload credentials with the specified flag state
with override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=flag_enabled):
with mock.patch('awx.main.models.credential.detect_server_product_name', return_value='NOT_AWX'):
load_credentials()
# Sync to database
CredentialType.setup_tower_managed_defaults(lock=False)
# In tests, the session fixture pre-loads all credential types into the DB.
# Remove OIDC types when testing the disabled state so the API test is accurate.
if not flag_enabled:
CredentialType.objects.filter(namespace__in=OIDC_CREDENTIAL_TYPE_NAMESPACES).delete()
yield _reload
# Restore original registry state after tests
ManagedCredentialType.registry.clear()
ManagedCredentialType.registry.update(original_registry)
@pytest.fixture
def isolated_registry():
"""Save and restore the ManagedCredentialType registry, with full isolation via mocked entry_points."""
original_registry = ManagedCredentialType.registry.copy()
ManagedCredentialType.registry.clear()
yield
ManagedCredentialType.registry.clear()
ManagedCredentialType.registry.update(original_registry)
def _make_mock_entry_point(name):
"""Create a mock entry point that mimics a credential plugin."""
ep = mock.MagicMock()
ep.name = name
ep.value = f'test_plugin:{name}'
plugin = mock.MagicMock(spec=[])
ep.load.return_value = plugin
return ep
def _mock_entry_points_factory(managed_names, supported_names):
"""Return a side_effect function for mocking entry_points() with controlled plugins."""
managed = [_make_mock_entry_point(n) for n in managed_names]
supported = [_make_mock_entry_point(n) for n in supported_names]
def _entry_points(group):
if group == 'awx_plugins.managed_credentials':
return managed
elif group == 'awx_plugins.managed_credentials.supported':
return supported
return []
return _entry_points
# --- Unit tests for load_credentials() registry behavior ---
def test_oidc_types_in_registry_when_flag_enabled(isolated_registry):
"""Test that OIDC credential types are added to the registry when flag is enabled."""
mock_eps = _mock_entry_points_factory(
managed_names=['ssh', 'vault'],
supported_names=['hashivault-kv-oidc', 'hashivault-ssh-oidc'],
)
with override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=True):
with mock.patch('awx.main.models.credential.detect_server_product_name', return_value='NOT_AWX'):
with mock.patch('awx.main.models.credential.entry_points', side_effect=mock_eps):
load_credentials()
for ns in OIDC_CREDENTIAL_TYPE_NAMESPACES:
assert ns in ManagedCredentialType.registry, f"{ns} should be in registry when flag is enabled"
assert 'ssh' in ManagedCredentialType.registry
assert 'vault' in ManagedCredentialType.registry
def test_oidc_types_not_in_registry_when_flag_disabled(isolated_registry):
"""Test that OIDC credential types are excluded from the registry when flag is disabled."""
mock_eps = _mock_entry_points_factory(
managed_names=['ssh', 'vault'],
supported_names=['hashivault-kv-oidc', 'hashivault-ssh-oidc'],
)
with override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=False):
with mock.patch('awx.main.models.credential.detect_server_product_name', return_value='NOT_AWX'):
with mock.patch('awx.main.models.credential.entry_points', side_effect=mock_eps):
load_credentials()
for ns in OIDC_CREDENTIAL_TYPE_NAMESPACES:
assert ns not in ManagedCredentialType.registry, f"{ns} should not be in registry when flag is disabled"
# Non-OIDC types should still be loaded
assert 'ssh' in ManagedCredentialType.registry
assert 'vault' in ManagedCredentialType.registry
def test_oidc_namespaces_constant():
"""Test that OIDC_CREDENTIAL_TYPE_NAMESPACES contains the expected namespaces."""
assert 'hashivault-kv-oidc' in OIDC_CREDENTIAL_TYPE_NAMESPACES
assert 'hashivault-ssh-oidc' in OIDC_CREDENTIAL_TYPE_NAMESPACES
assert len(OIDC_CREDENTIAL_TYPE_NAMESPACES) == 2
# --- Functional API tests ---
@pytest.mark.django_db
def test_oidc_types_loaded_when_flag_enabled(get, admin, reload_credentials_with_flag):
"""Test that OIDC credential types are visible in the API when flag is enabled."""
reload_credentials_with_flag(flag_enabled=True)
response = get(reverse('api:credential_type_list'), admin)
assert response.status_code == 200
namespaces = [ct['namespace'] for ct in response.data['results']]
assert 'hashivault-kv-oidc' in namespaces
assert 'hashivault-ssh-oidc' in namespaces
@pytest.mark.django_db
def test_oidc_types_not_loaded_when_flag_disabled(get, admin, reload_credentials_with_flag):
"""Test that OIDC credential types are not visible in the API when flag is disabled."""
reload_credentials_with_flag(flag_enabled=False)
response = get(reverse('api:credential_type_list'), admin)
assert response.status_code == 200
namespaces = [ct['namespace'] for ct in response.data['results']]
assert 'hashivault-kv-oidc' not in namespaces
assert 'hashivault-ssh-oidc' not in namespaces
# Verify they're also not in the database
assert not CredentialType.objects.filter(namespace='hashivault-kv-oidc').exists()
assert not CredentialType.objects.filter(namespace='hashivault-ssh-oidc').exists()

View File

@@ -1,311 +0,0 @@
"""
Tests for OIDC workload identity credential test endpoints.
Tests the /api/v2/credentials/<id>/test/ and /api/v2/credential_types/<id>/test/
endpoints when used with OIDC-enabled credential types.
"""
import pytest
from unittest import mock
from django.test import override_settings
from awx.main.models import Credential, CredentialType, JobTemplate
from awx.api.versioning import reverse
@pytest.fixture
def job_template(organization, project):
"""Job template with organization and project for OIDC JWT generation."""
return JobTemplate.objects.create(name='test-jt', organization=organization, project=project, playbook='helloworld.yml')
@pytest.fixture
def oidc_credentialtype():
"""Create a credential type with workload_identity_token internal field."""
oidc_type_inputs = {
'fields': [
{'id': 'url', 'label': 'Vault URL', 'type': 'string', 'help_text': 'The Vault server URL.'},
{'id': 'auth_path', 'label': 'Auth Path', 'type': 'string', 'help_text': 'JWT auth mount path.'},
{'id': 'role_id', 'label': 'Role ID', 'type': 'string', 'help_text': 'Vault role.'},
{'id': 'workload_identity_token', 'label': 'Workload Identity Token', 'type': 'string', 'secret': True, 'internal': True},
],
'metadata': [
{'id': 'secret_path', 'label': 'Secret Path', 'type': 'string'},
{'id': 'job_template_id', 'label': 'Job Template ID', 'type': 'string'},
],
'required': ['url', 'auth_path', 'role_id'],
}
class MockPlugin(object):
def backend(self, **kwargs):
# Simulate successful backend call
return 'secret'
with mock.patch('awx.main.models.credential.CredentialType.plugin', new_callable=mock.PropertyMock) as mock_plugin:
mock_plugin.return_value = MockPlugin()
oidc_type = CredentialType(kind='external', managed=True, namespace='hashivault-kv-oidc', name='HashiCorp Vault KV (OIDC)', inputs=oidc_type_inputs)
oidc_type.save()
yield oidc_type
@pytest.fixture
def oidc_credential(oidc_credentialtype):
"""Create a credential using the OIDC credential type."""
return Credential.objects.create(
credential_type=oidc_credentialtype,
name='oidc-vault-cred',
inputs={'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role'},
)
@pytest.fixture
def mock_oidc_backend():
"""Fixture that mocks OIDC JWT generation and credential backend."""
with mock.patch('awx.api.views.retrieve_workload_identity_jwt_with_claims') as mock_jwt, mock.patch('awx.api.views._jwt_decode') as mock_decode, mock.patch(
'awx.main.models.credential.CredentialType.plugin', new_callable=mock.PropertyMock
) as mock_plugin:
# Set default return values
mock_jwt.return_value = 'fake.jwt.token'
mock_decode.return_value = {'iss': 'http://gateway/o', 'aud': 'vault'}
# Create mock backend
mock_backend = mock.MagicMock()
mock_backend.backend.return_value = 'secret'
mock_plugin.return_value = mock_backend
# Yield all mocks for test customization
yield {
'jwt': mock_jwt,
'decode': mock_decode,
'plugin': mock_plugin,
'backend': mock_backend,
}
# --- Tests for CredentialExternalTest endpoint ---
@pytest.mark.django_db
@override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=False)
def test_credential_test_without_oidc_feature_flag(post, admin, oidc_credential):
"""Test that credential test works without OIDC feature flag enabled."""
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': '1'}}
with mock.patch('awx.main.models.credential.CredentialType.plugin', new_callable=mock.PropertyMock) as mock_plugin:
mock_backend = mock.MagicMock()
mock_backend.backend.return_value = 'secret'
mock_plugin.return_value = mock_backend
response = post(url, data, admin)
assert response.status_code == 202
# Should not contain JWT payload when feature flag is disabled
assert 'details' not in response.data or 'sent_jwt_payload' not in response.data.get('details', {})
@pytest.mark.django_db
@mock.patch('awx.api.views.flag_enabled', return_value=True)
@pytest.mark.parametrize(
'job_template_id, expected_error',
[
(None, 'Job template ID is required'),
('not-an-integer', 'must be an integer'),
('99999', 'does not exist'),
],
ids=['missing_job_template_id', 'invalid_job_template_id_type', 'nonexistent_job_template_id'],
)
def test_credential_test_job_template_validation(mock_flag, post, admin, oidc_credential, job_template_id, expected_error):
"""Test that invalid job_template_id values return 400 with appropriate error messages."""
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
data = {'metadata': {'secret_path': 'test/secret'}}
if job_template_id is not None:
data['metadata']['job_template_id'] = job_template_id
response = post(url, data, admin)
assert response.status_code == 400
assert 'details' in response.data
assert 'error_message' in response.data['details']
assert expected_error in response.data['details']['error_message']
@pytest.mark.django_db
@mock.patch('awx.api.views.flag_enabled', return_value=True)
def test_credential_test_no_access_to_job_template(mock_flag, post, alice, oidc_credential, job_template):
"""Test that user without access to job template gets 403."""
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
# Give alice use permission on credential but not on job template
oidc_credential.use_role.members.add(alice)
response = post(url, data, alice)
assert response.status_code == 403
assert 'You do not have access to job template' in str(response.data)
@pytest.mark.django_db
@mock.patch('awx.api.views.flag_enabled', return_value=True)
def test_credential_test_success_returns_jwt_payload(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend):
"""Test that successful test returns JWT payload in response."""
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
# Customize mock for this test
mock_oidc_backend['decode'].return_value = {
'iss': 'http://gateway/o',
'sub': 'system:serviceaccount:default:awx-operator',
'aud': 'vault',
'job_template_id': job_template.id,
}
response = post(url, data, admin)
assert response.status_code == 202
assert 'details' in response.data
assert 'sent_jwt_payload' in response.data['details']
assert response.data['details']['sent_jwt_payload']['job_template_id'] == job_template.id
@pytest.mark.django_db
@mock.patch('awx.api.views.flag_enabled', return_value=True)
def test_credential_test_response_does_not_contain_secret_value(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend):
"""
the OIDC credential test endpoint must not echo the resolved Vault secret back to the caller.
"""
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
credential_secret_value = 'CREDENTIAL_SECRET'
mock_oidc_backend['backend'].backend.return_value = credential_secret_value
response = post(url, data, admin)
assert response.status_code == 202
assert 'details' in response.data
assert 'sent_jwt_payload' in response.data['details']
assert 'secret_value' not in response.data['details']
assert credential_secret_value not in str(response.data)
@pytest.mark.django_db
@mock.patch('awx.api.views.flag_enabled', return_value=True)
def test_credential_test_backend_failure_returns_jwt_and_error(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend):
"""Test that backend failure still returns JWT payload along with error message."""
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
# Make backend fail
mock_oidc_backend['backend'].backend.side_effect = RuntimeError('Connection failed')
response = post(url, data, admin)
assert response.status_code == 400
assert 'details' in response.data
# Both JWT payload and error message should be present
assert 'sent_jwt_payload' in response.data['details']
assert 'error_message' in response.data['details']
assert 'Connection failed' in response.data['details']['error_message']
@pytest.mark.django_db
@mock.patch('awx.api.views.flag_enabled', return_value=True)
def test_credential_test_jwt_generation_failure(mock_flag, post, admin, oidc_credential, job_template):
"""Test that JWT generation failure returns error without JWT payload."""
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
with mock.patch('awx.api.views.OIDCCredentialTestMixin._get_workload_identity_token') as mock_jwt:
mock_jwt.side_effect = RuntimeError('Failed to generate JWT')
response = post(url, data, admin)
assert response.status_code == 400
assert 'details' in response.data
assert 'error_message' in response.data['details']
assert 'Failed to generate JWT' in response.data['details']['error_message']
# No JWT payload when generation fails
assert 'sent_jwt_payload' not in response.data['details']
@pytest.mark.django_db
@mock.patch('awx.api.views.flag_enabled', return_value=True)
def test_credential_test_job_template_id_not_passed_to_backend(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend):
"""Test that job_template_id is removed from backend_kwargs."""
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
response = post(url, data, admin)
assert response.status_code == 202
# Check that backend was called without job_template_id but with url and workload_identity_token
call_kwargs = mock_oidc_backend['backend'].backend.call_args[1]
assert 'job_template_id' not in call_kwargs
assert 'url' in call_kwargs
assert 'workload_identity_token' in call_kwargs
# --- Tests for CredentialTypeExternalTest endpoint ---
@pytest.mark.django_db
@mock.patch('awx.api.views.flag_enabled', return_value=True)
def test_credential_type_test_response_does_not_contain_secret_value(mock_flag, post, admin, oidc_credentialtype, job_template, mock_oidc_backend):
"""
the credential-type variant of the test endpoint should not return the secret value
"""
url = reverse('api:credential_type_external_test', kwargs={'pk': oidc_credentialtype.pk})
data = {
'inputs': {'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role', 'jwt_aud': 'vault'},
'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)},
}
credential_type_seret_value = 'CREDENTIAL_TYPE_SECRET'
mock_oidc_backend['backend'].backend.return_value = credential_type_seret_value
response = post(url, data, admin)
assert response.status_code == 202
assert 'details' in response.data
assert 'sent_jwt_payload' in response.data['details']
assert 'secret_value' not in response.data['details']
assert credential_type_seret_value not in str(response.data)
@pytest.mark.django_db
@mock.patch('awx.api.views.flag_enabled', return_value=True)
def test_credential_type_test_missing_job_template_id(mock_flag, post, admin, oidc_credentialtype):
"""Test that missing job_template_id returns 400 for credential type test endpoint."""
url = reverse('api:credential_type_external_test', kwargs={'pk': oidc_credentialtype.pk})
data = {
'inputs': {'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role'},
'metadata': {'secret_path': 'test/secret'},
}
response = post(url, data, admin)
assert response.status_code == 400
assert 'details' in response.data
assert 'error_message' in response.data['details']
assert 'Job template ID is required' in response.data['details']['error_message']
@pytest.mark.django_db
@mock.patch('awx.api.views.flag_enabled', return_value=True)
def test_credential_type_test_success_returns_jwt_payload(mock_flag, post, admin, oidc_credentialtype, job_template, mock_oidc_backend):
"""Test that successful credential type test returns JWT payload."""
url = reverse('api:credential_type_external_test', kwargs={'pk': oidc_credentialtype.pk})
data = {
'inputs': {'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role'},
'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)},
}
response = post(url, data, admin)
assert response.status_code == 202
assert 'details' in response.data
assert 'sent_jwt_payload' in response.data['details']
@pytest.mark.django_db
def test_credential_external_test_returns_400_for_non_external_credential(post, admin, credential):
# credential fixture creates a non-external credential (e.g. SSH/vault kind)
url = reverse('api:credential_external_test', kwargs={'pk': credential.pk})
response = post(url, {'metadata': {}}, admin)
assert response.status_code == 400
assert 'not testable' in response.data.get('detail', '').lower()

View File

@@ -145,124 +145,3 @@ def test_delete_ad_hoc_command_in_active_state(ad_hoc_command_factory, delete, a
adhoc = ad_hoc_command_factory(initial_state=status)
url = reverse('api:ad_hoc_command_detail', kwargs={'pk': adhoc.pk})
delete(url, None, admin, expect=403)
@pytest.fixture
def job_with_heavy_fields(job_factory):
job = job_factory()
job.extra_vars = '{"some_var": "some_value"}'
job.artifacts = {"some_artifact": "some_value"}
job.save()
return job
def _job_result(response, job_id):
for row in response.data['results']:
if row['id'] == job_id:
return row
raise AssertionError('job {} not found in {}'.format(job_id, [r['id'] for r in response.data['results']]))
@pytest.mark.django_db
def test_unified_jobs_list_strips_heavy_fields_by_default(get, admin, job_with_heavy_fields):
response = get(reverse('api:unified_job_list') + '?id={}'.format(job_with_heavy_fields.id), admin, expect=200)
row = _job_result(response, job_with_heavy_fields.id)
assert 'artifacts' not in row
assert 'extra_vars' not in row
@pytest.mark.django_db
def test_unified_jobs_list_include_artifacts(get, admin, job_with_heavy_fields):
response = get(
reverse('api:unified_job_list') + '?id={}&include=artifacts'.format(job_with_heavy_fields.id),
admin,
expect=200,
)
row = _job_result(response, job_with_heavy_fields.id)
assert 'artifacts' in row
assert 'extra_vars' not in row
@pytest.mark.django_db
def test_unified_jobs_list_include_extra_vars(get, admin, job_with_heavy_fields):
response = get(
reverse('api:unified_job_list') + '?id={}&include=extra_vars'.format(job_with_heavy_fields.id),
admin,
expect=200,
)
row = _job_result(response, job_with_heavy_fields.id)
assert 'extra_vars' in row
assert 'artifacts' not in row
@pytest.mark.django_db
def test_unified_jobs_list_include_both(get, admin, job_with_heavy_fields):
response = get(
reverse('api:unified_job_list') + '?id={}&include=artifacts,extra_vars'.format(job_with_heavy_fields.id),
admin,
expect=200,
)
row = _job_result(response, job_with_heavy_fields.id)
assert 'artifacts' in row
assert 'extra_vars' in row
@pytest.mark.django_db
def test_unified_jobs_list_include_tolerates_whitespace(get, admin, job_with_heavy_fields):
response = get(
reverse('api:unified_job_list') + '?id={}&include=%20artifacts%20,%20extra_vars%20'.format(job_with_heavy_fields.id),
admin,
expect=200,
)
row = _job_result(response, job_with_heavy_fields.id)
assert 'artifacts' in row
assert 'extra_vars' in row
@pytest.mark.django_db
def test_unified_jobs_list_include_ignores_unknown(get, admin, job_with_heavy_fields):
response = get(
reverse('api:unified_job_list') + '?id={}&include=does_not_exist'.format(job_with_heavy_fields.id),
admin,
expect=200,
)
row = _job_result(response, job_with_heavy_fields.id)
assert 'artifacts' not in row
assert 'extra_vars' not in row
@pytest.mark.django_db
def test_unified_jobs_list_include_does_not_honor_disallowed(get, admin, job_with_heavy_fields):
# event_processing_finished triggers a count(*) on main_jobevent and must
# not be re-enabled via the public ?include= param.
response = get(
reverse('api:unified_job_list') + '?id={}&include=event_processing_finished,job_args,result_traceback'.format(job_with_heavy_fields.id),
admin,
expect=200,
)
row = _job_result(response, job_with_heavy_fields.id)
assert 'event_processing_finished' not in row
assert 'job_args' not in row
assert 'result_traceback' not in row
assert 'artifacts' not in row
assert 'extra_vars' not in row
@pytest.mark.django_db
def test_jobs_list_strips_heavy_fields_by_default(get, admin, job_with_heavy_fields):
response = get(reverse('api:job_list') + '?id={}'.format(job_with_heavy_fields.id), admin, expect=200)
row = _job_result(response, job_with_heavy_fields.id)
assert 'artifacts' not in row
assert 'extra_vars' not in row
@pytest.mark.django_db
def test_jobs_list_include_extra_vars(get, admin, job_with_heavy_fields):
response = get(
reverse('api:job_list') + '?id={}&include=extra_vars'.format(job_with_heavy_fields.id),
admin,
expect=200,
)
row = _job_result(response, job_with_heavy_fields.id)
assert 'extra_vars' in row
assert 'artifacts' not in row

View File

@@ -1,3 +1,4 @@
from datetime import date
from unittest import mock
import pytest
@@ -252,7 +253,7 @@ def test_user_verify_attribute_created(admin, get):
resp = get(reverse('api:user_detail', kwargs={'pk': admin.pk}), admin)
assert resp.data['created'] == admin.date_joined
past = "2020-01-01T00:00:00Z"
past = date(2020, 1, 1).isoformat()
for op, count in (('gt', 1), ('lt', 0)):
resp = get(reverse('api:user_list') + f'?created__{op}={past}', admin)
assert resp.data['count'] == count

View File

@@ -13,7 +13,6 @@ from awx.main.models.workflow import (
WorkflowJobTemplateNode,
)
from awx.main.models.credential import Credential
from awx.main.models.label import Label
from awx.main.scheduler import TaskManager, WorkflowManager, DependencyManager
# Django
@@ -52,31 +51,6 @@ def test_node_accepts_prompted_fields(inventory, project, workflow_job_template,
post(url, {'unified_job_template': job_template.pk, 'limit': 'webservers'}, user=admin_user, expect=201)
@pytest.mark.django_db
def test_node_extra_data_patch_with_unprompted_labels(inventory, project, organization, workflow_job_template, patch, admin_user):
"""AAP-41742: PATCH extra_data on a workflow node should succeed even when
the node has labels associated but the JT has ask_labels_on_launch=False."""
jt = JobTemplate.objects.create(
inventory=inventory,
project=project,
playbook='helloworld.yml',
ask_variables_on_launch=True,
ask_labels_on_launch=False,
)
label = Label.objects.create(name='repro-label', organization=organization)
node = WorkflowJobTemplateNode.objects.create(
workflow_job_template=workflow_job_template,
unified_job_template=jt,
extra_data={'foo': 'bar'},
)
node.labels.add(label)
url = reverse('api:workflow_job_template_node_detail', kwargs={'pk': node.pk})
r = patch(url, {'extra_data': {'foo': 'edited'}}, user=admin_user, expect=200)
assert r.data['extra_data'] == {'foo': 'edited'}
@pytest.mark.django_db
@pytest.mark.parametrize(
"field_name, field_value",

View File

@@ -48,7 +48,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
worker = CallbackBrokerWorker()
events = [InventoryUpdateEvent(uuid=str(uuid4()), **self.event_create_kwargs())]
worker.buff = {InventoryUpdateEvent: events}
worker.flush(force=True)
worker.flush()
assert worker.buff.get(InventoryUpdateEvent, []) == []
assert InventoryUpdateEvent.objects.filter(uuid=events[0].uuid).count() == 1
@@ -61,7 +61,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
InventoryUpdateEvent(uuid=str(uuid4()), stdout='good2', **kwargs),
]
worker.buff = {InventoryUpdateEvent: events.copy()}
worker.flush(force=True)
worker.flush()
assert InventoryUpdateEvent.objects.filter(uuid=events[0].uuid).count() == 1
assert InventoryUpdateEvent.objects.filter(uuid=events[1].uuid).count() == 0
assert InventoryUpdateEvent.objects.filter(uuid=events[2].uuid).count() == 1
@@ -71,7 +71,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
worker = CallbackBrokerWorker()
events = [InventoryUpdateEvent(uuid=str(uuid4()), **self.event_create_kwargs())]
worker.buff = {InventoryUpdateEvent: events.copy()}
worker.flush(force=True)
worker.flush()
# put current saved event in buffer (error case)
worker.buff = {InventoryUpdateEvent: [InventoryUpdateEvent.objects.get(uuid=events[0].uuid)]}
@@ -113,7 +113,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
with mock.patch.object(InventoryUpdateEvent.objects, 'bulk_create', side_effect=ValueError):
with mock.patch.object(events[0], 'save', side_effect=ValueError):
worker.flush(force=True)
worker.flush()
assert "\x00" not in events[0].stdout

View File

@@ -305,47 +305,6 @@ class TestINIImports:
has_host_group = inventory.groups.get(name='has_a_host')
assert has_host_group.hosts.count() == 1
@mock.patch.object(inventory_import, 'AnsibleInventoryLoader', MockLoader)
def test_overwrite_removes_stale_memberships(self, inventory):
"""When overwrite is enabled, host-group and group-group memberships
that are no longer in the imported data should be removed."""
# First import: parent_group has two children, host_group has two hosts
inventory_import.AnsibleInventoryLoader._data = {
"_meta": {"hostvars": {"host1": {}, "host2": {}}},
"all": {"children": ["ungrouped", "parent_group", "child_a", "child_b", "host_group"]},
"parent_group": {"children": ["child_a", "child_b"]},
"host_group": {"hosts": ["host1", "host2"]},
"ungrouped": {"hosts": []},
}
cmd = inventory_import.Command()
cmd.handle(inventory_id=inventory.pk, source=__file__, overwrite=True)
parent = inventory.groups.get(name='parent_group')
assert set(parent.children.values_list('name', flat=True)) == {'child_a', 'child_b'}
host_grp = inventory.groups.get(name='host_group')
assert set(host_grp.hosts.values_list('name', flat=True)) == {'host1', 'host2'}
# Second import: child_b removed from parent_group, host2 moved out of host_group
inventory_import.AnsibleInventoryLoader._data = {
"_meta": {"hostvars": {"host1": {}, "host2": {}}},
"all": {"children": ["ungrouped", "parent_group", "child_a", "child_b", "host_group"]},
"parent_group": {"children": ["child_a"]},
"host_group": {"hosts": ["host1"]},
"ungrouped": {"hosts": ["host2"]},
}
cmd = inventory_import.Command()
cmd.handle(inventory_id=inventory.pk, source=__file__, overwrite=True)
parent.refresh_from_db()
host_grp.refresh_from_db()
# child_b should be removed from parent_group
assert set(parent.children.values_list('name', flat=True)) == {'child_a'}
# host2 should be removed from host_group
assert set(host_grp.hosts.values_list('name', flat=True)) == {'host1'}
# host2 and child_b should still exist in the inventory, just not in those groups
assert inventory.hosts.filter(name='host2').exists()
assert inventory.groups.filter(name='child_b').exists()
@mock.patch.object(inventory_import, 'AnsibleInventoryLoader', MockLoader)
def test_recursive_group_error(self, inventory):
inventory_import.AnsibleInventoryLoader._data = {

View File

@@ -131,18 +131,14 @@ def test_workflow_creation_permissions(setup_managed_roles, organization, workfl
@pytest.mark.django_db
def test_assign_credential_to_user_of_another_org(setup_managed_roles, credential, admin_user, rando, org_admin, organization, post):
'''Test that a credential can only be assigned to a user in the same organization by non-superusers'''
'''Test that a credential can only be assigned to a user in the same organization'''
# cannot assign credential to rando, as rando is not in the same org as the credential
rd = RoleDefinition.objects.get(name="Credential Admin")
credential.organization = organization
credential.save(update_fields=['organization'])
assert credential.organization not in Organization.access_qs(rando, 'member')
url = django_reverse('roleuserassignment-list')
# superuser can assign cross-org
post(url=url, data={"user": rando.id, "role_definition": rd.id, "object_id": credential.id}, user=admin_user, expect=201)
# non-superuser (org_admin) cannot assign cross-org
resp = post(url=url, data={"user": rando.id, "role_definition": rd.id, "object_id": credential.id}, user=org_admin, expect=400)
resp = post(url=url, data={"user": rando.id, "role_definition": rd.id, "object_id": credential.id}, user=admin_user, expect=400)
assert "You cannot grant credential access to a User not in the credentials' organization" in str(resp.data)
# can assign credential to superuser
@@ -150,7 +146,7 @@ def test_assign_credential_to_user_of_another_org(setup_managed_roles, credentia
rando.save()
post(url=url, data={"user": rando.id, "role_definition": rd.id, "object_id": credential.id}, user=admin_user, expect=201)
# can assign credential to org_admin (same org)
# can assign credential to org_admin
assert credential.organization in Organization.access_qs(org_admin, 'member')
post(url=url, data={"user": org_admin.id, "role_definition": rd.id, "object_id": credential.id}, user=admin_user, expect=201)

View File

@@ -1,41 +0,0 @@
"""
Tests for AAP-68023: host_list_rbac performance optimization.
The host list endpoint fetches the large ansible_facts JSON column
unnecessarily. The HostManager now defers it by default so that
list queries avoid transferring this data from PostgreSQL.
"""
import pytest
from awx.main.models import Host
# ---------------------------------------------------------------------------
# AAP-68023: Verify ansible_facts column is deferred by HostManager
# ---------------------------------------------------------------------------
@pytest.mark.django_db
class TestHostManagerDeferral:
"""AAP-68023: The host list fetches 200+ columns unnecessarily.
The ansible_facts JSON column is large and not used by the list
serializer. HostManager.get_queryset() must defer it so that
every query through Host.objects avoids fetching it by default.
"""
def test_ansible_facts_deferred_by_default(self):
"""ansible_facts should be in the deferred set for default Host queries."""
qs = Host.objects.all()
deferred = qs.query.deferred_loading[0]
assert 'ansible_facts' in deferred, f'ansible_facts should be deferred by the HostManager. ' f'Deferred fields: {deferred}'
def test_ansible_facts_accessible_when_needed(self, inventory):
"""Deferred fields are still accessible — Django fetches on access."""
host = Host.objects.create(
name='facts-host',
inventory=inventory,
ansible_facts={'os': 'linux'},
)
loaded = Host.objects.get(pk=host.pk)
assert loaded.ansible_facts == {'os': 'linux'}

View File

@@ -71,10 +71,8 @@ class TestEvents:
assert s.skipped == 0
for host in Host.objects.all():
latest_summary = JobHostSummary.latest_for_host(host.id)
assert latest_summary is not None
assert latest_summary.job_id == self.job.id
assert latest_summary.host == host
assert host.last_job_id == self.job.id
assert host.last_job_host_summary.host == host
def test_host_summary_generation_with_deleted_hosts(self):
self._generate_hosts(10)
@@ -93,7 +91,8 @@ class TestEvents:
def test_host_summary_generation_with_limit(self):
# Make an inventory with 10 hosts, run a playbook with a --limit
# pointed at *one* host,
# Verify that *only* that host has an associated JobHostSummary.
# Verify that *only* that host has an associated JobHostSummary and that
# *only* that host has an updated value for .last_job.
self._generate_hosts(10)
# by making the playbook_on_stats *only* include Host 1, we're emulating
@@ -106,14 +105,13 @@ class TestEvents:
# be related to the appropriate Host)
assert JobHostSummary.objects.count() == 1
for h in Host.objects.all():
latest_summary = JobHostSummary.latest_for_host(h.id)
if h.name == 'Host 1':
assert latest_summary is not None
assert latest_summary.job_id == self.job.id
assert latest_summary.id == JobHostSummary.objects.first().id
assert h.last_job_id == self.job.id
assert h.last_job_host_summary_id == JobHostSummary.objects.first().id
else:
# all other hosts in the inventory should have no summary
assert latest_summary is None
# all other hosts in the inventory should remain untouched
assert h.last_job_id is None
assert h.last_job_host_summary_id is None
def test_host_metrics_insert(self):
self._generate_hosts(10)

View File

@@ -10,23 +10,9 @@ from django.test.utils import override_settings
@pytest.mark.django_db
def test_multiple_hybrid_instances():
for i in range(3):
Instance.objects.create(hostname=f'foo{i}', node_type='hybrid')
assert is_ha_environment()
@pytest.mark.django_db
def test_double_control_instances():
def test_multiple_instances():
for i in range(2):
Instance.objects.create(hostname=f'foo{i}', node_type='control')
assert is_ha_environment()
@pytest.mark.django_db
def test_mix_hybrid_control_instances():
Instance.objects.create(hostname='control_node', node_type='control')
Instance.objects.create(hostname='hybrid_node', node_type='hybrid')
Instance.objects.create(hostname=f'foo{i}', node_type='hybrid')
assert is_ha_environment()

View File

@@ -1,213 +0,0 @@
import pytest
from django.test.utils import CaptureQueriesContext
from django.db import connection
from django.utils.timezone import now
from awx.main.models import Job, JobEvent, Inventory, Host, JobHostSummary
@pytest.mark.django_db
class TestHostLatestSummaryQuerySet:
"""Tests for HostLatestSummaryQuerySet and Host.latest_summary property."""
def _create_inventory_with_hosts(self, count=5):
inventory = Inventory()
inventory.save()
Host.objects.bulk_create([Host(created=now(), modified=now(), name=f'host-{i}', inventory_id=inventory.id) for i in range(count)])
return inventory
def _run_job(self, inventory, host_names=None):
"""Run a fake job that creates JobHostSummary records for the given hosts."""
if host_names is None:
host_names = list(inventory.hosts.values_list('name', flat=True))
job = Job(inventory=inventory)
job.save()
host_map = dict(inventory.hosts.values_list('name', 'id'))
JobEvent.create_from_data(
job_id=job.pk,
parent_uuid='abc123',
event='playbook_on_stats',
event_data={
'ok': {name: 1 for name in host_names},
'changed': {},
'dark': {},
'failures': {},
'ignored': {},
'processed': {},
'rescued': {},
'skipped': {},
},
host_map=host_map,
).save()
return job
def test_with_latest_summary_id_annotates_hosts(self):
inventory = self._create_inventory_with_hosts(3)
job = self._run_job(inventory)
hosts = Host.objects.filter(inventory=inventory).with_latest_summary_id()
for host in hosts:
assert hasattr(host, '_latest_summary_id')
summary = JobHostSummary.objects.filter(host=host, job=job).first()
assert host._latest_summary_id == summary.id
def test_with_latest_summary_id_returns_most_recent(self):
inventory = self._create_inventory_with_hosts(1)
self._run_job(inventory)
job2 = self._run_job(inventory)
host = Host.objects.filter(inventory=inventory).with_latest_summary_id().first()
latest = JobHostSummary.objects.filter(host_id=host.id).order_by('-id').first()
assert latest.job_id == job2.id
assert host._latest_summary_id == latest.id
def test_with_latest_summary_id_none_for_no_summaries(self):
inventory = self._create_inventory_with_hosts(1)
# No job run — no summaries
host = Host.objects.filter(inventory=inventory).with_latest_summary_id().first()
assert host._latest_summary_id is None
def test_fetch_all_bulk_attaches_summaries(self):
inventory = self._create_inventory_with_hosts(5)
self._run_job(inventory)
hosts = list(Host.objects.filter(inventory=inventory).with_latest_summary_id())
for host in hosts:
assert hasattr(host, '_latest_summary_cache')
assert host._latest_summary_cache is not None
assert isinstance(host._latest_summary_cache, JobHostSummary)
def test_fetch_all_skips_non_annotated_querysets(self):
"""Non-annotated querysets should NOT set _latest_summary_cache,
preserving the per-object fallback in Host.latest_summary."""
inventory = self._create_inventory_with_hosts(3)
self._run_job(inventory)
hosts = list(Host.objects.filter(inventory=inventory))
for host in hosts:
assert not hasattr(host, '_latest_summary_cache')
def test_count_does_not_trigger_fetch_all(self):
"""Calling .count() should not trigger _fetch_all or the bulk-attach logic."""
inventory = self._create_inventory_with_hosts(5)
self._run_job(inventory)
qs = Host.objects.filter(inventory=inventory).with_latest_summary_id()
with CaptureQueriesContext(connection) as ctx:
result = qs.count()
assert result == 5
# count() should produce a single COUNT query, not fetch all rows + summaries
assert len(ctx.captured_queries) == 1
assert 'COUNT' in ctx.captured_queries[0]['sql'].upper()
def test_exists_does_not_trigger_fetch_all(self):
inventory = self._create_inventory_with_hosts(1)
self._run_job(inventory)
qs = Host.objects.filter(inventory=inventory).with_latest_summary_id()
with CaptureQueriesContext(connection) as ctx:
result = qs.exists()
assert result is True
assert len(ctx.captured_queries) == 1
def test_latest_summary_property_uses_cache(self):
"""When loaded via with_latest_summary_id(), Host.latest_summary
should use the bulk-attached cache without extra queries."""
inventory = self._create_inventory_with_hosts(3)
self._run_job(inventory)
hosts = list(Host.objects.filter(inventory=inventory).with_latest_summary_id())
with CaptureQueriesContext(connection) as ctx:
for host in hosts:
summary = host.latest_summary
assert summary is not None
# No additional queries — all data came from the bulk-attach
assert len(ctx.captured_queries) == 0
def test_latest_summary_property_fallback(self):
"""When loaded without annotation, Host.latest_summary should
fall back to a per-object query."""
inventory = self._create_inventory_with_hosts(1)
job = self._run_job(inventory)
host = Host.objects.filter(inventory=inventory).first()
assert not hasattr(host, '_latest_summary_cache')
summary = host.latest_summary
assert summary is not None
assert summary.job_id == job.id
# After first access, the cache should be populated
assert hasattr(host, '_latest_summary_cache')
def test_latest_summary_none_when_no_summaries(self):
inventory = self._create_inventory_with_hosts(1)
host = Host.objects.filter(inventory=inventory).with_latest_summary_id().first()
assert host.latest_summary is None
def test_latest_job_property(self):
inventory = self._create_inventory_with_hosts(1)
job = self._run_job(inventory)
host = Host.objects.filter(inventory=inventory).with_latest_summary_id().first()
assert host.latest_job is not None
assert host.latest_job.id == job.id
def test_latest_job_none_when_no_summaries(self):
inventory = self._create_inventory_with_hosts(1)
host = Host.objects.filter(inventory=inventory).first()
assert host.latest_job is None
def test_bulk_attach_select_related(self):
"""The bulk-attach should select_related job and job__job_template
so accessing them doesn't cause extra queries."""
inventory = self._create_inventory_with_hosts(3)
self._run_job(inventory)
hosts = list(Host.objects.filter(inventory=inventory).with_latest_summary_id())
with CaptureQueriesContext(connection) as ctx:
for host in hosts:
summary = host.latest_summary
_ = summary.job # should not query
assert len(ctx.captured_queries) == 0
def test_chaining_preserves_annotation(self):
"""Chaining .filter() after .with_latest_summary_id() should
preserve the annotation and bulk-attach behavior."""
inventory = self._create_inventory_with_hosts(5)
self._run_job(inventory)
hosts = list(Host.objects.filter(inventory=inventory).with_latest_summary_id().filter(name__startswith='host-').order_by('name'))
assert len(hosts) == 5
for host in hosts:
assert hasattr(host, '_latest_summary_cache')
assert host._latest_summary_cache is not None
def test_multiple_jobs_latest_wins(self):
"""After multiple jobs, latest_summary should return the most recent."""
inventory = self._create_inventory_with_hosts(1)
self._run_job(inventory)
self._run_job(inventory)
job3 = self._run_job(inventory)
host = Host.objects.filter(inventory=inventory).with_latest_summary_id().first()
assert host.latest_summary.job_id == job3.id
def test_partial_host_coverage(self):
"""When a job only touches some hosts, only those hosts get summaries."""
inventory = self._create_inventory_with_hosts(5)
self._run_job(inventory, host_names=['host-0', 'host-1'])
hosts = list(Host.objects.filter(inventory=inventory).with_latest_summary_id().order_by('name'))
with_summary = [h for h in hosts if h.latest_summary is not None]
without_summary = [h for h in hosts if h.latest_summary is None]
assert len(with_summary) == 2
assert len(without_summary) == 3
assert sorted([h.name for h in with_summary]) == ['host-0', 'host-1']

View File

@@ -1,111 +0,0 @@
import pytest
from django.utils.timezone import now
from awx.main.models import Job, JobEvent, JobTemplate, Inventory, Host, JobHostSummary, Project
from awx.api.serializers import HostSerializer
@pytest.mark.django_db
class TestHostSummaryFields:
"""Tests for summary_fields of last_job and last_job_host_summary on HostSerializer."""
def _setup_host_with_job(self, status='canceled'):
inventory = Inventory()
inventory.save()
host = Host(created=now(), modified=now(), name='test-host', inventory=inventory)
host.save()
project = Project(name='test-project')
project.save()
jt = JobTemplate(name='test-jt', inventory=inventory, project=project)
jt.save()
job = Job(inventory=inventory, job_template=jt, status=status)
if status in ('successful', 'failed', 'canceled', 'error'):
job.finished = now()
if status == 'canceled':
job.canceled_on = now()
job.save()
host_map = {host.name: host.id}
JobEvent.create_from_data(
job_id=job.pk,
parent_uuid='abc123',
event='playbook_on_stats',
event_data={
'ok': {host.name: 1},
'changed': {},
'dark': {},
'failures': {},
'ignored': {},
'processed': {},
'rescued': {},
'skipped': {},
},
host_map=host_map,
).save()
summary = JobHostSummary.objects.filter(host=host, job=job).first()
host.last_job = job
host.last_job_host_summary = summary
host.save(update_fields=['last_job', 'last_job_host_summary'])
host.refresh_from_db()
return host, job, summary
def test_last_job_summary_fields_canceled_job(self):
host, job, summary = self._setup_host_with_job(status='canceled')
serializer = HostSerializer()
d = serializer.get_summary_fields(host)
assert 'last_job' in d
last_job = d['last_job']
expected_keys = {'id', 'name', 'description', 'finished', 'status', 'failed', 'canceled_on', 'job_template_id', 'job_template_name'}
assert set(last_job.keys()) == expected_keys, f"Unexpected last_job keys: {set(last_job.keys())}"
assert last_job['id'] == job.id
assert last_job['status'] == 'canceled'
assert last_job['canceled_on'] == job.canceled_on
assert last_job['job_template_id'] == job.job_template.id
assert last_job['job_template_name'] == job.job_template.name
def test_last_job_summary_fields_successful_job(self):
host, job, summary = self._setup_host_with_job(status='successful')
serializer = HostSerializer()
d = serializer.get_summary_fields(host)
assert 'last_job' in d
last_job = d['last_job']
expected_keys = {'id', 'name', 'description', 'finished', 'status', 'failed', 'job_template_id', 'job_template_name'}
assert set(last_job.keys()) == expected_keys, f"Unexpected last_job keys: {set(last_job.keys())}"
assert last_job['id'] == job.id
assert last_job['status'] == 'successful'
assert 'canceled_on' not in last_job, "canceled_on should not appear when None"
def test_last_job_host_summary_fields(self):
host, job, summary = self._setup_host_with_job(status='successful')
serializer = HostSerializer()
d = serializer.get_summary_fields(host)
assert 'last_job_host_summary' in d
last_jhs = d['last_job_host_summary']
assert last_jhs['id'] == summary.id
assert 'failed' in last_jhs
def test_no_summary_fields_without_job(self):
inventory = Inventory()
inventory.save()
host = Host(created=now(), modified=now(), name='lonely-host', inventory=inventory)
host.save()
serializer = HostSerializer()
d = serializer.get_summary_fields(host)
assert 'last_job' not in d
assert 'last_job_host_summary' not in d

View File

@@ -108,28 +108,6 @@ class TestActiveCount:
source.hosts.create(name='remotely-managed-host', inventory=inventory)
assert Host.objects.active_count() == 1
def test_active_count_minus_constructed(self, organization):
"""
Active hosts do not include duplicated hosts from construted inventories.
"""
inv = Inventory.objects.create(name='source-inv', organization=organization)
inv.hosts.create(name='host1')
assert Host.objects.active_count() == 1
constructed = Inventory.objects.create(name='constructed-inv', kind='constructed', organization=organization)
Host.objects.create(name='host1', inventory=constructed)
assert Host.objects.active_count() == 1
def test_org_active_count_minus_constructed(self, organization):
"""Org-scoped count must also exclude constructed-inventory shadow rows."""
inv = Inventory.objects.create(name='source-inv', organization=organization)
inv.hosts.create(name='host1')
assert Host.objects.org_active_count(organization.id) == 1
constructed = Inventory.objects.create(name='constructed-inv', kind='constructed', organization=organization)
Host.objects.create(name='host1', inventory=constructed)
assert Host.objects.org_active_count(organization.id) == 1
def test_host_case_insensitivity(self, organization):
inv1 = Inventory.objects.create(name='inv1', organization=organization)
inv2 = Inventory.objects.create(name='inv2', organization=organization)

View File

@@ -1,6 +1,6 @@
import pytest
from awx.main.models import JobTemplate, Job, JobHostSummary, WorkflowJob, Inventory, Host, Project, Organization
from awx.main.models import JobTemplate, Job, JobHostSummary, WorkflowJob, Inventory, Project, Organization
@pytest.mark.django_db
@@ -87,47 +87,3 @@ class TestSlicingModels:
unified_job = job_template.create_unified_job(job_slice_count=2)
assert isinstance(unified_job, Job)
@pytest.mark.django_db
class TestGetSourceHostsForConstructedInventory:
"""Tests for Job.get_source_hosts_for_constructed_inventory"""
def test_returns_source_hosts_via_instance_id(self):
"""Constructed hosts with instance_id pointing to source hosts are resolved correctly."""
org = Organization.objects.create(name='test-org')
inv_input = Inventory.objects.create(organization=org, name='input-inv')
source_host1 = inv_input.hosts.create(name='host1')
source_host2 = inv_input.hosts.create(name='host2')
inv_constructed = Inventory.objects.create(organization=org, name='constructed-inv', kind='constructed')
inv_constructed.input_inventories.add(inv_input)
Host.objects.create(inventory=inv_constructed, name='host1', instance_id=str(source_host1.id))
Host.objects.create(inventory=inv_constructed, name='host2', instance_id=str(source_host2.id))
job = Job.objects.create(name='test-job', inventory=inv_constructed)
result = job.get_source_hosts_for_constructed_inventory()
assert set(result.values_list('id', flat=True)) == {source_host1.id, source_host2.id}
def test_no_inventory_returns_empty(self):
"""A job with no inventory returns an empty queryset."""
job = Job.objects.create(name='test-job')
result = job.get_source_hosts_for_constructed_inventory()
assert result.count() == 0
def test_ignores_hosts_without_instance_id(self):
"""Hosts with empty instance_id are excluded from the result."""
org = Organization.objects.create(name='test-org')
inv_input = Inventory.objects.create(organization=org, name='input-inv')
source_host = inv_input.hosts.create(name='host1')
inv_constructed = Inventory.objects.create(organization=org, name='constructed-inv', kind='constructed')
inv_constructed.input_inventories.add(inv_input)
Host.objects.create(inventory=inv_constructed, name='host1', instance_id=str(source_host.id))
Host.objects.create(inventory=inv_constructed, name='host-no-ref', instance_id='')
job = Job.objects.create(name='test-job', inventory=inv_constructed)
result = job.get_source_hosts_for_constructed_inventory()
assert list(result.values_list('id', flat=True)) == [source_host.id]

View File

@@ -8,7 +8,7 @@ from crum import impersonate
# AWX
from awx.main.models import UnifiedJobTemplate, Job, JobTemplate, WorkflowJobTemplate, Project, WorkflowJob, Schedule, Credential
from awx.api.versioning import reverse
from awx.main.utils.common import get_job_variable_prefixes
from awx.main.constants import JOB_VARIABLE_PREFIXES
@pytest.mark.django_db
@@ -160,13 +160,7 @@ class TestMetaVars:
job = Job.objects.create(name='job', created_by=admin_user)
job.save()
user_vars = [
'_'.join(x)
for x in itertools.product(
get_job_variable_prefixes(),
['user_name', 'user_id', 'user_email', 'user_first_name', 'user_last_name'],
)
]
user_vars = ['_'.join(x) for x in itertools.product(['tower', 'awx'], ['user_name', 'user_id', 'user_email', 'user_first_name', 'user_last_name'])]
for key in user_vars:
assert key in job.awx_meta_vars()
@@ -185,7 +179,7 @@ class TestMetaVars:
workflow_job.workflow_nodes.create(job=job)
data = job.awx_meta_vars()
for name in get_job_variable_prefixes():
for name in JOB_VARIABLE_PREFIXES:
assert data['{}_user_id'.format(name)] == admin_user.id
assert data['{}_user_name'.format(name)] == admin_user.username
assert data['{}_workflow_job_id'.format(name)] == workflow_job.pk
@@ -195,7 +189,7 @@ class TestMetaVars:
schedule = Schedule.objects.create(name='job-schedule', rrule='DTSTART:20171129T155939z\nFREQ=MONTHLY', unified_job_template=job_template)
job = Job.objects.create(name='fake-job', launch_type='workflow', schedule=schedule, job_template=job_template)
data = job.awx_meta_vars()
for name in get_job_variable_prefixes():
for name in JOB_VARIABLE_PREFIXES:
assert data['{}_schedule_id'.format(name)] == schedule.pk
assert '{}_user_name'.format(name) not in data
@@ -207,7 +201,7 @@ class TestMetaVars:
job = Job.objects.create(launch_type='workflow')
workflow_job.workflow_nodes.create(job=job)
result_hash = {}
for name in get_job_variable_prefixes():
for name in JOB_VARIABLE_PREFIXES:
result_hash['{}_job_id'.format(name)] = job.id
result_hash['{}_job_launch_type'.format(name)] = 'workflow'
result_hash['{}_workflow_job_name'.format(name)] = 'workflow-job'

View File

@@ -1,274 +0,0 @@
# Generated by Claude Opus 4.6 (claude-opus-4-6)
#
# Test file for cancel + dependency chain behavior and workflow cancel propagation.
#
# These tests verify:
#
# 1. TaskManager.process_job_dep_failures() correctly distinguishes canceled vs
# failed dependencies in the job_explanation message.
#
# 2. TaskManager.process_pending_tasks() transitions pending jobs with
# cancel_flag=True directly to canceled status.
#
# 3. WorkflowManager + TaskManager together cancel all spawned jobs in a
# workflow and finalize the workflow as canceled.
import pytest
from unittest import mock
from awx.main.scheduler import TaskManager, DependencyManager, WorkflowManager
from awx.main.models import JobTemplate, ProjectUpdate, WorkflowApproval, WorkflowJobTemplate
from awx.main.models.workflow import WorkflowApprovalTemplate
from . import create_job
@pytest.fixture
def scm_on_launch_objects(job_template_factory):
"""Create a job template with a project configured for scm_update_on_launch."""
objects = job_template_factory(
'jt',
organization='org1',
project='proj',
inventory='inv',
credential='cred',
)
p = objects.project
p.scm_update_on_launch = True
p.scm_update_cache_timeout = 0
p.save(skip_update=True)
return objects
def _create_job_with_dependency(objects):
"""Create a pending job and run DependencyManager to produce its project update dependency.
Returns (job, project_update).
"""
j = create_job(objects.job_template, dependencies_processed=False)
with mock.patch('awx.main.models.unified_jobs.UnifiedJobTemplate.update'):
DependencyManager().schedule()
assert j.dependent_jobs.count() == 1
pu = j.dependent_jobs.first()
assert isinstance(pu.get_real_instance(), ProjectUpdate)
return j, pu
@pytest.mark.django_db
class TestCanceledDependencyFailsBlockedJob:
"""When a dependency project update is canceled or failed, the task manager
should fail the blocked job via process_job_dep_failures."""
def test_canceled_dependency_fails_blocked_job(self, controlplane_instance_group, scm_on_launch_objects):
"""A canceled dependency causes the blocked job to be failed with
a 'Previous Task Canceled' explanation."""
j, pu = _create_job_with_dependency(scm_on_launch_objects)
ProjectUpdate.objects.filter(pk=pu.pk).update(status='canceled', cancel_flag=True)
with mock.patch("awx.main.scheduler.TaskManager.start_task"):
TaskManager().schedule()
j.refresh_from_db()
assert j.status == 'failed'
assert 'Previous Task Canceled' in j.job_explanation
def test_failed_dependency_fails_blocked_job(self, controlplane_instance_group, scm_on_launch_objects):
"""A failed dependency causes the blocked job to be failed with
a 'Previous Task Failed' explanation."""
j, pu = _create_job_with_dependency(scm_on_launch_objects)
ProjectUpdate.objects.filter(pk=pu.pk).update(status='failed')
with mock.patch("awx.main.scheduler.TaskManager.start_task"):
TaskManager().schedule()
j.refresh_from_db()
assert j.status == 'failed'
assert 'Previous Task Failed' in j.job_explanation
@pytest.mark.django_db
class TestTaskManagerCancelsPendingJobsWithCancelFlag:
"""When the task manager encounters pending jobs that have cancel_flag set,
it should transition them directly to canceled status."""
def test_pending_job_with_cancel_flag_is_canceled(self, controlplane_instance_group, job_template_factory):
"""A pending job with cancel_flag=True is transitioned to canceled
by the task manager without being started."""
objects = job_template_factory(
'jt',
organization='org1',
project='proj',
inventory='inv',
credential='cred',
)
j = create_job(objects.job_template)
j.cancel_flag = True
j.save(update_fields=['cancel_flag'])
with mock.patch("awx.main.scheduler.TaskManager.start_task") as mock_start:
TaskManager().schedule()
j.refresh_from_db()
assert j.status == 'canceled'
assert 'canceled before it started' in j.job_explanation
assert not mock_start.called
def test_pending_job_without_cancel_flag_is_not_canceled(self, controlplane_instance_group, job_template_factory):
"""A normal pending job without cancel_flag should not be canceled
by the task manager (sanity check)."""
objects = job_template_factory(
'jt',
organization='org1',
project='proj',
inventory='inv',
credential='cred',
)
j = create_job(objects.job_template)
with mock.patch("awx.main.scheduler.TaskManager.start_task"):
TaskManager().schedule()
j.refresh_from_db()
assert j.status != 'canceled'
def test_multiple_pending_jobs_with_cancel_flag_bulk_canceled(self, controlplane_instance_group, job_template_factory):
"""Multiple pending jobs with cancel_flag=True are all transitioned
to canceled in a single task manager cycle."""
objects = job_template_factory(
'jt',
organization='org1',
project='proj',
inventory='inv',
credential='cred',
)
jt = objects.job_template
jt.allow_simultaneous = True
jt.save()
jobs = []
for _ in range(3):
j = create_job(jt)
j.cancel_flag = True
j.save(update_fields=['cancel_flag'])
jobs.append(j)
with mock.patch("awx.main.scheduler.TaskManager.start_task") as mock_start:
TaskManager().schedule()
for j in jobs:
j.refresh_from_db()
assert j.status == 'canceled', f"Job {j.id} should be canceled but is {j.status}"
assert 'canceled before it started' in j.job_explanation
assert not mock_start.called
@pytest.mark.django_db
class TestWorkflowCancelFinalizesWorkflow:
"""When a workflow job is canceled, the WorkflowManager cancels spawned child
jobs (setting cancel_flag), the TaskManager transitions those pending jobs to
canceled, and a final WorkflowManager pass finalizes the workflow as canceled."""
def test_cancel_workflow_with_parallel_nodes(self, inventory, project, controlplane_instance_group):
"""Create a workflow with parallel nodes, cancel it after one job is
running, and verify all jobs and the workflow reach canceled status."""
jt = JobTemplate.objects.create(allow_simultaneous=False, inventory=inventory, project=project, playbook='helloworld.yml')
wfjt = WorkflowJobTemplate.objects.create(name='test-cancel-wf')
for _ in range(4):
wfjt.workflow_nodes.create(unified_job_template=jt)
wj = wfjt.create_unified_job()
wj.signal_start()
# TaskManager transitions workflow job to running via start_task
TaskManager().schedule()
wj.refresh_from_db()
assert wj.status == 'running'
# WorkflowManager spawns jobs for all 4 nodes
WorkflowManager().schedule()
assert jt.jobs.count() == 4
# Simulate one job running (blocking the others via allow_simultaneous=False)
first_job = jt.jobs.order_by('created').first()
first_job.status = 'running'
first_job.celery_task_id = 'fake-task-id'
first_job.controller_node = 'test-node'
first_job.save(update_fields=['status', 'celery_task_id', 'controller_node'])
# Cancel the workflow
wj.cancel_flag = True
wj.save(update_fields=['cancel_flag'])
# WorkflowManager sees cancel_flag, calls cancel_node_jobs() which sets
# cancel_flag on all child jobs
with mock.patch('awx.main.models.unified_jobs.UnifiedJob.cancel_dispatcher_process'):
WorkflowManager().schedule()
# The running job won't actually stop in tests (no dispatcher), simulate it
first_job.status = 'canceled'
first_job.save(update_fields=['status'])
# TaskManager processes remaining pending jobs with cancel_flag set
with mock.patch("awx.main.scheduler.TaskManager.start_task") as mock_start:
DependencyManager().schedule()
TaskManager().schedule()
for job in jt.jobs.all():
job.refresh_from_db()
assert job.status == 'canceled', f"Job {job.id} should be canceled but is {job.status}"
assert not mock_start.called
# Final WorkflowManager pass finalizes the workflow
WorkflowManager().schedule()
wj.refresh_from_db()
assert wj.status == 'canceled'
def test_cancel_workflow_with_approval_node(self, controlplane_instance_group):
"""Create a workflow with a pending approval node and a downstream job
node. Cancel the workflow and verify the approval is directly canceled
by the WorkflowManager (since approvals are excluded from TaskManager),
the downstream node is marked do_not_run, and the workflow finalizes."""
approval_template = WorkflowApprovalTemplate.objects.create(name='test-approval', timeout=0)
wfjt = WorkflowJobTemplate.objects.create(name='test-cancel-approval-wf')
approval_node = wfjt.workflow_nodes.create(unified_job_template=approval_template)
# Add a downstream node (just another approval to keep it simple)
downstream_template = WorkflowApprovalTemplate.objects.create(name='test-downstream', timeout=0)
downstream_node = wfjt.workflow_nodes.create(unified_job_template=downstream_template)
approval_node.success_nodes.add(downstream_node)
wj = wfjt.create_unified_job()
wj.signal_start()
# TaskManager transitions workflow to running
TaskManager().schedule()
wj.refresh_from_db()
assert wj.status == 'running'
# WorkflowManager spawns the approval (root node only, downstream waits)
WorkflowManager().schedule()
assert WorkflowApproval.objects.filter(unified_job_node__workflow_job=wj).count() == 1
approval_job = WorkflowApproval.objects.get(unified_job_node__workflow_job=wj)
assert approval_job.status == 'pending'
# Cancel the workflow
wj.cancel_flag = True
wj.save(update_fields=['cancel_flag'])
# WorkflowManager should cancel the approval directly and mark
# the downstream node as do_not_run
WorkflowManager().schedule()
approval_job.refresh_from_db()
assert approval_job.status == 'canceled', f"Approval should be canceled directly by WorkflowManager but is {approval_job.status}"
# Downstream node should be marked do_not_run with no job spawned
downstream_wj_node = wj.workflow_nodes.get(unified_job_template=downstream_template)
assert downstream_wj_node.do_not_run is True
assert downstream_wj_node.job is None
# Workflow should finalize as canceled in the same pass
wj.refresh_from_db()
assert wj.status == 'canceled'

View File

@@ -1,223 +0,0 @@
"""Functional tests for start_fact_cache / finish_fact_cache.
These tests use real database objects (via pytest-django) and real files
on disk, but do not launch jobs or subprocesses. Fact files are written
by start_fact_cache and then manipulated to simulate ansible output
before calling finish_fact_cache.
Generated by Claude Opus 4.6 (claude-opus-4-6).
"""
import json
import os
import time
from datetime import timedelta
import pytest
from django.utils.timezone import now
from awx.main.models import Host, Inventory
from awx.main.tasks.facts import start_fact_cache, finish_fact_cache
@pytest.fixture
def artifacts_dir(tmp_path):
d = tmp_path / 'artifacts'
d.mkdir()
return str(d)
@pytest.mark.django_db
class TestFinishFactCacheScoping:
"""finish_fact_cache must only update hosts matched by the provided queryset."""
def test_same_hostname_different_inventories(self, organization, artifacts_dir):
"""Two inventories share a hostname; only the targeted one should be updated.
Generated by Claude Opus 4.6 (claude-opus-4-6).
"""
inv1 = Inventory.objects.create(organization=organization, name='scope-inv1')
inv2 = Inventory.objects.create(organization=organization, name='scope-inv2')
host1 = inv1.hosts.create(name='shared')
host2 = inv2.hosts.create(name='shared')
# Give both hosts initial facts
for h in (host1, host2):
h.ansible_facts = {'original': True}
h.ansible_facts_modified = now()
h.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
# start_fact_cache writes reference files for inv1's hosts
start_fact_cache(inv1.hosts.all(), artifacts_dir=artifacts_dir, timeout=0, inventory_id=inv1.id)
# Simulate ansible writing new facts for 'shared'
fact_file = os.path.join(artifacts_dir, 'fact_cache', 'shared')
future = time.time() + 60
with open(fact_file, 'w') as f:
json.dump({'updated': True}, f)
os.utime(fact_file, (future, future))
# finish with inv1's hosts as the queryset
finish_fact_cache(inv1.hosts, artifacts_dir=artifacts_dir, inventory_id=inv1.id)
host1.refresh_from_db()
host2.refresh_from_db()
assert host1.ansible_facts == {'updated': True}
assert host2.ansible_facts == {'original': True}, 'Host in a different inventory was modified despite not being in the queryset'
@pytest.mark.django_db
class TestFinishFactCacheConcurrentProtection:
"""finish_fact_cache must not clear facts that a concurrent job updated."""
def test_no_clear_when_no_file_was_written(self, organization, artifacts_dir):
"""Host with no prior facts should not have facts cleared when file is missing.
Generated by Claude Opus 4.6 (claude-opus-4-6).
start_fact_cache records hosts_cached[host] = False for hosts with no
prior facts (no file written). finish_fact_cache should skip the clear
for these hosts because the missing file is expected, not a clear signal.
"""
inv = Inventory.objects.create(organization=organization, name='concurrent-inv')
host = inv.hosts.create(name='target')
job_created = now() - timedelta(minutes=5)
# start_fact_cache records host with False (no facts → no file written)
start_fact_cache(inv.hosts.all(), artifacts_dir=artifacts_dir, timeout=0, inventory_id=inv.id)
# Simulate a concurrent job updating this host's facts AFTER our job was created
host.ansible_facts = {'from_concurrent_job': True}
host.ansible_facts_modified = now()
host.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
# The fact file is missing because start_fact_cache never wrote one.
# finish_fact_cache should skip this host entirely.
finish_fact_cache(
inv.hosts,
artifacts_dir=artifacts_dir,
inventory_id=inv.id,
job_created=job_created,
)
host.refresh_from_db()
assert host.ansible_facts == {'from_concurrent_job': True}, 'Facts were cleared for a host that never had a fact file written'
def test_skip_clear_when_facts_modified_after_job_created(self, organization, artifacts_dir):
"""If a file was written and then deleted, but facts were concurrently updated, skip clear.
Generated by Claude Opus 4.6 (claude-opus-4-6).
"""
inv = Inventory.objects.create(organization=organization, name='concurrent-written-inv')
host = inv.hosts.create(name='target')
old_time = now() - timedelta(hours=1)
host.ansible_facts = {'original': True}
host.ansible_facts_modified = old_time
host.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
job_created = now() - timedelta(minutes=5)
# start_fact_cache writes a file (host has facts → True in map)
start_fact_cache(inv.hosts.all(), artifacts_dir=artifacts_dir, timeout=0, inventory_id=inv.id)
# Remove the fact file (ansible didn't target this host via --limit)
os.remove(os.path.join(artifacts_dir, 'fact_cache', host.name))
# Simulate a concurrent job updating this host's facts AFTER our job was created
host.ansible_facts = {'from_concurrent_job': True}
host.ansible_facts_modified = now()
host.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
finish_fact_cache(
inv.hosts,
artifacts_dir=artifacts_dir,
inventory_id=inv.id,
job_created=job_created,
)
host.refresh_from_db()
assert host.ansible_facts == {'from_concurrent_job': True}, 'Facts set by a concurrent job were cleared despite ansible_facts_modified > job_created'
def test_clear_when_facts_predate_job(self, organization, artifacts_dir):
"""If facts predate the job, a missing file should still clear them.
Generated by Claude Opus 4.6 (claude-opus-4-6).
"""
inv = Inventory.objects.create(organization=organization, name='clear-inv')
host = inv.hosts.create(name='stale')
old_time = now() - timedelta(hours=1)
host.ansible_facts = {'stale': True}
host.ansible_facts_modified = old_time
host.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
job_created = now() - timedelta(minutes=5)
start_fact_cache(inv.hosts.all(), artifacts_dir=artifacts_dir, timeout=0, inventory_id=inv.id)
# Remove the fact file to simulate ansible's clear_facts
os.remove(os.path.join(artifacts_dir, 'fact_cache', host.name))
finish_fact_cache(
inv.hosts,
artifacts_dir=artifacts_dir,
inventory_id=inv.id,
job_created=job_created,
)
host.refresh_from_db()
assert host.ansible_facts == {}, 'Stale facts should have been cleared when the fact file is missing ' 'and ansible_facts_modified predates job_created'
@pytest.mark.django_db
class TestConstructedInventoryFactCache:
"""finish_fact_cache with a constructed inventory queryset must target source hosts."""
def test_facts_resolve_to_source_host(self, organization, artifacts_dir):
"""Facts must be written to the source host, not the constructed copy.
Generated by Claude Opus 4.6 (claude-opus-4-6).
"""
from django.db.models.functions import Cast
inv_input = Inventory.objects.create(organization=organization, name='ci-input')
source_host = inv_input.hosts.create(name='webserver')
inv_constructed = Inventory.objects.create(organization=organization, name='ci-constructed', kind='constructed')
inv_constructed.input_inventories.add(inv_input)
constructed_host = Host.objects.create(
inventory=inv_constructed,
name='webserver',
instance_id=str(source_host.id),
)
# Build the same queryset that get_hosts_for_fact_cache uses
id_field = Host._meta.get_field('id')
source_qs = Host.objects.filter(id__in=inv_constructed.hosts.exclude(instance_id='').values_list(Cast('instance_id', output_field=id_field)))
# Give the source host initial facts so start_fact_cache writes a file
source_host.ansible_facts = {'role': 'web'}
source_host.ansible_facts_modified = now()
source_host.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
start_fact_cache(source_qs, artifacts_dir=artifacts_dir, timeout=0, inventory_id=inv_constructed.id)
# Simulate ansible writing updated facts
fact_file = os.path.join(artifacts_dir, 'fact_cache', 'webserver')
future = time.time() + 60
with open(fact_file, 'w') as f:
json.dump({'role': 'web', 'deployed': True}, f)
os.utime(fact_file, (future, future))
finish_fact_cache(source_qs, artifacts_dir=artifacts_dir, inventory_id=inv_constructed.id)
source_host.refresh_from_db()
constructed_host.refresh_from_db()
assert source_host.ansible_facts == {'role': 'web', 'deployed': True}
assert not constructed_host.ansible_facts, f'Facts were stored on the constructed host: {constructed_host.ansible_facts!r}'

View File

@@ -29,30 +29,3 @@ def test_cancel_flag_on_start(jt_linked, caplog):
job = Job.objects.get(id=job.id)
assert job.status == 'canceled'
@pytest.mark.django_db
def test_runjob_run_can_accept_waiting_status(jt_linked, mocker):
"""Test that RunJob.run() can accept a job in 'waiting' status and transition it to 'running'
before the pre_run_hook is called"""
job = jt_linked.create_unified_job()
job.status = 'waiting'
job.save()
status_at_pre_run = None
def capture_status(instance, private_data_dir):
nonlocal status_at_pre_run
instance.refresh_from_db()
status_at_pre_run = instance.status
mock_pre_run = mocker.patch.object(RunJob, 'pre_run_hook', side_effect=capture_status)
task = RunJob()
try:
task.run(job.id)
except Exception:
pass
mock_pre_run.assert_called_once()
assert status_at_pre_run == 'running'

View File

@@ -8,8 +8,6 @@ from awx.main.models.jobs import JobTemplate
from awx.main.models import Organization, Inventory, WorkflowJob, ExecutionEnvironment, Host
from awx.main.scheduler import TaskManager
from django.test import override_settings
@pytest.mark.django_db
@pytest.mark.parametrize('num_hosts, num_queries', [(1, 15), (10, 15)])
@@ -447,185 +445,3 @@ def get_inventory_hosts(get, inv_id, use_user):
data = get(reverse('api:inventory_hosts_list', kwargs={'pk': inv_id}), use_user, expect=200).data
results = [host['id'] for host in data['results']]
return results
@pytest.mark.django_db
def test_bulk_job_launch_respects_settings_limit(job_template, organization, inventory, project, post, patch, get, user):
"""Test that bulk job launch respects BULK_JOB_MAX_LAUNCH setting."""
normal_user = user('normal_user', False)
organization.member_role.members.add(normal_user)
jt = JobTemplate.objects.create(
name='bulk-test-jt',
ask_inventory_on_launch=True,
project=project,
playbook='helloworld.yml',
allow_simultaneous=True,
)
jt.execute_role.members.add(normal_user)
inventory.use_role.members.add(normal_user)
# Test with limit set to 3
with override_settings(BULK_JOB_MAX_LAUNCH=3):
# Attempt to launch 5 jobs when limit is 3 - should fail
jobs = [{'unified_job_template': jt.id, 'inventory': inventory.id} for _ in range(5)]
resp = post(
reverse('api:bulk_job_launch'),
{'name': 'Bulk Job Test', 'jobs': jobs},
normal_user,
expect=400,
)
assert 'Number of requested jobs exceeds system setting' in str(resp.data)
# Test with limit increased to 10
with override_settings(BULK_JOB_MAX_LAUNCH=10):
# Now launching 5 jobs should succeed
jobs = [{'unified_job_template': jt.id, 'inventory': inventory.id} for _ in range(5)]
resp = post(
reverse('api:bulk_job_launch'),
{'name': 'Bulk Job Test', 'jobs': jobs},
normal_user,
expect=201,
)
bulk_job = get(resp.data['url'], normal_user, expect=200).data
# Verify the workflow job was created
assert bulk_job['name'] == 'Bulk Job Test'
# Tests for BulkHostCreateSerializer duplicate detection optimization
@pytest.mark.django_db
def test_bulk_host_create_duplicate_within_batch(organization, inventory, post, user):
"""
Test that duplicate hostnames within the same batch are detected.
This tests the Counter-based duplicate detection logic.
"""
inventory.organization = organization
inv_admin = user('inventory_admin', False)
organization.member_role.members.add(inv_admin)
inventory.admin_role.members.add(inv_admin)
# Try to create hosts where 'duplicate-host' appears twice in the same batch
hosts = [
{'name': 'unique-host-1'},
{'name': 'duplicate-host'},
{'name': 'unique-host-2'},
{'name': 'duplicate-host'}, # Duplicate within batch
]
response = post(reverse('api:bulk_host_create'), {'inventory': inventory.id, 'hosts': hosts}, inv_admin, expect=400)
assert 'Hostnames must be unique in an inventory' in response.data['__all__'][0]
assert 'duplicate-host' in response.data['__all__'][0]
assert Host.objects.filter(inventory=inventory).count() == 0
@pytest.mark.django_db
def test_bulk_host_create_duplicate_against_existing(organization, inventory, post, user):
"""
Test that duplicate hostnames against existing inventory hosts are detected.
This tests the database query-based duplicate detection.
"""
inventory.organization = organization
inv_admin = user('inventory_admin', False)
organization.member_role.members.add(inv_admin)
inventory.admin_role.members.add(inv_admin)
Host.objects.create(name='existing-host-1', inventory=inventory)
Host.objects.create(name='existing-host-2', inventory=inventory)
# Try to create hosts where one already exists
hosts = [
{'name': 'new-host-1'},
{'name': 'existing-host-1'},
{'name': 'new-host-2'},
]
response = post(reverse('api:bulk_host_create'), {'inventory': inventory.id, 'hosts': hosts}, inv_admin, expect=400)
assert 'Hostnames must be unique in an inventory' in response.data['__all__'][0]
assert 'existing-host-1' in response.data['__all__'][0]
assert Host.objects.filter(inventory=inventory).count() == 2
@pytest.mark.django_db
def test_bulk_host_create_combined_duplicates(organization, inventory, post, user):
"""
Test detection of both batch-internal duplicates and duplicates against existing hosts.
"""
inventory.organization = organization
inventory_admin = user('inventory_admin', False)
organization.member_role.members.add(inventory_admin)
inventory.admin_role.members.add(inventory_admin)
Host.objects.create(name='existing-host', inventory=inventory)
# Try to create hosts with both types of duplicates
hosts = [
{'name': 'new-host'},
{'name': 'batch-duplicate'},
{'name': 'existing-host'},
{'name': 'batch-duplicate'},
]
response = post(reverse('api:bulk_host_create'), {'inventory': inventory.id, 'hosts': hosts}, inventory_admin, expect=400)
error_message = response.data['__all__'][0]
assert 'Hostnames must be unique in an inventory' in error_message
assert 'batch-duplicate' in error_message or 'existing-host' in error_message
@pytest.mark.django_db
def test_bulk_host_create_no_duplicates_success(organization, inventory, post, user):
"""
Test that hosts are created successfully when there are no duplicates.
"""
inventory.organization = organization
inventory_admin = user('inventory_admin', False)
organization.member_role.members.add(inventory_admin)
inventory.admin_role.members.add(inventory_admin)
Host.objects.create(name='existing-host-1', inventory=inventory)
Host.objects.create(name='existing-host-2', inventory=inventory)
# Create new hosts with unique names
hosts = [
{'name': 'new-host-1'},
{'name': 'new-host-2'},
{'name': 'new-host-3'},
]
response = post(reverse('api:bulk_host_create'), {'inventory': inventory.id, 'hosts': hosts}, inventory_admin, expect=201)
assert len(response.data['hosts']) == 3
assert Host.objects.filter(inventory=inventory).count() == 5
assert Host.objects.filter(inventory=inventory, name='new-host-1').exists()
assert Host.objects.filter(inventory=inventory, name='new-host-2').exists()
assert Host.objects.filter(inventory=inventory, name='new-host-3').exists()
@pytest.mark.django_db
def test_bulk_host_create_performance_large_inventory(organization, inventory, post, user, django_assert_max_num_queries):
"""
Test that duplicate detection is performant and doesn't load all hosts.
"""
inventory.organization = organization
inventory_admin = user('inventory_admin', False)
organization.member_role.members.add(inventory_admin)
inventory.admin_role.members.add(inventory_admin)
# Create 10k existing hosts to simulate a reasonably large inventory
from django.utils.timezone import now
_now = now()
existing_hosts = [Host(name=f'existing-host-{i}', inventory=inventory, created=_now, modified=_now) for i in range(10000)]
Host.objects.bulk_create(existing_hosts)
new_hosts = [{'name': f'new-host-{i}'} for i in range(10)]
# The number of queries should be bounded and not scale with inventory size
# This should be around 15-20 queries regardless of whether there are 10k or 500k+ existing hosts
with django_assert_max_num_queries(20):
response = post(reverse('api:bulk_host_create'), {'inventory': inventory.id, 'hosts': new_hosts}, inventory_admin, expect=201)
assert len(response.data['hosts']) == 10
assert Host.objects.filter(inventory=inventory).count() == 10010

View File

@@ -74,64 +74,47 @@ GLqbpJyX2r3p/Rmo6mLY71SqpA==
@pytest.mark.django_db
def test_default_cred_types():
expected = [
'aim',
'aws',
'aws_secretsmanager_credential',
'azure_kv',
'azure_rm',
'bitbucket_dc_token',
'centrify_vault_kv',
'conjur',
'controller',
'galaxy_api_token',
'gce',
'github_token',
'github_app_lookup',
'gitlab_token',
'gpg_public_key',
'hashivault_kv',
'hashivault_ssh',
'hcp_terraform',
'insights',
'kubernetes_bearer_token',
'net',
'openstack',
'registry',
'rhv',
'satellite6',
'scm',
'ssh',
'terraform',
'thycotic_dsv',
'thycotic_tss',
'vault',
'vmware',
]
assert sorted(CredentialType.defaults.keys()) == sorted(expected)
assert 'hashivault-kv-oidc' not in CredentialType.defaults
assert 'hashivault-ssh-oidc' not in CredentialType.defaults
assert sorted(CredentialType.defaults.keys()) == sorted(
[
'aim',
'aws',
'aws_secretsmanager_credential',
'azure_kv',
'azure_rm',
'bitbucket_dc_token',
'centrify_vault_kv',
'conjur',
'controller',
'galaxy_api_token',
'gce',
'github_token',
'github_app_lookup',
'gitlab_token',
'gpg_public_key',
'hashivault_kv',
'hashivault_ssh',
'hcp_terraform',
'insights',
'kubernetes_bearer_token',
'net',
'openstack',
'registry',
'rhv',
'satellite6',
'scm',
'ssh',
'terraform',
'thycotic_dsv',
'thycotic_tss',
'vault',
'vmware',
]
)
for type_ in CredentialType.defaults.values():
assert type_().managed is True
@pytest.mark.django_db
def test_default_cred_types_with_oidc_enabled():
from django.test import override_settings
from awx.main.models.credential import load_credentials, ManagedCredentialType
original_registry = ManagedCredentialType.registry.copy()
try:
with override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=True):
ManagedCredentialType.registry.clear()
load_credentials()
assert 'hashivault-kv-oidc' in CredentialType.defaults
assert 'hashivault-ssh-oidc' in CredentialType.defaults
finally:
ManagedCredentialType.registry = original_registry
@pytest.mark.django_db
def test_credential_creation(organization_factory):
org = organization_factory('test').organization

View File

@@ -8,7 +8,6 @@ from awx.main.models import (
Instance,
Host,
JobHostSummary,
Inventory,
InventoryUpdate,
InventorySource,
Project,
@@ -18,60 +17,14 @@ from awx.main.models import (
InstanceGroup,
Label,
ExecutionEnvironment,
Credential,
CredentialType,
CredentialInputSource,
Organization,
JobTemplate,
)
from awx.main.tasks import jobs
from awx.main.tasks.system import cluster_node_heartbeat
from awx.main.utils.db import bulk_update_sorted_by_id
from ansible_base.lib.testing.util import feature_flag_enabled, feature_flag_disabled
from django.db import OperationalError
from django.test.utils import override_settings
@pytest.fixture
def job_template_with_credentials():
"""
Factory fixture that creates a job template with specified credentials.
Usage:
job = job_template_with_credentials(ssh_cred, vault_cred)
"""
def _create_job_template(
*credentials, org_name='test-org', project_name='test-project', inventory_name='test-inventory', jt_name='test-jt', playbook='test.yml'
):
"""
Create a job template with the given credentials.
Args:
*credentials: Variable number of Credential objects to attach to the job template
org_name: Name for the organization
project_name: Name for the project
inventory_name: Name for the inventory
jt_name: Name for the job template
playbook: Playbook filename
Returns:
Job instance created from the job template
"""
org = Organization.objects.create(name=org_name)
proj = Project.objects.create(name=project_name, organization=org)
inv = Inventory.objects.create(name=inventory_name, organization=org)
jt = JobTemplate.objects.create(name=jt_name, project=proj, inventory=inv, playbook=playbook)
if credentials:
jt.credentials.add(*credentials)
return jt.create_unified_job()
return _create_job_template
@pytest.mark.django_db
def test_orphan_unified_job_creation(instance, inventory):
job = Job.objects.create(job_template=None, inventory=inventory, name='hi world')
@@ -309,440 +262,3 @@ class TestLaunchConfig:
assert config.execution_environment
# We just write the PK instead of trying to assign an item, that happens on the save
assert config.execution_environment_id == ee.id
@pytest.mark.django_db
def test_base_task_credentials_property(job_template_with_credentials):
"""Test that _credentials property caches credentials and doesn't re-query."""
task = jobs.RunJob()
# Create real credentials
ssh_type = CredentialType.defaults['ssh']()
ssh_type.save()
vault_type = CredentialType.defaults['vault']()
vault_type.save()
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
vault_cred = Credential.objects.create(credential_type=vault_type, name='vault-cred')
# Create a job with credentials using fixture
job = job_template_with_credentials(ssh_cred, vault_cred)
task.instance = job
# First access should build credentials
result1 = task._credentials
assert len(result1) == 2
assert isinstance(result1, list)
# Second access should return cached value (we can verify by checking it's the same list object)
result2 = task._credentials
assert result2 is result1 # Same object reference
@pytest.mark.django_db
def test_run_job_machine_credential(job_template_with_credentials):
"""Test _machine_credential returns ssh credential from cache."""
task = jobs.RunJob()
# Create credentials
ssh_type = CredentialType.defaults['ssh']()
ssh_type.save()
vault_type = CredentialType.defaults['vault']()
vault_type.save()
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
vault_cred = Credential.objects.create(credential_type=vault_type, name='vault-cred')
# Create a job using fixture
job = job_template_with_credentials(ssh_cred, vault_cred)
task.instance = job
# Set cached credentials
task._credentials = [ssh_cred, vault_cred]
# Get machine credential
result = task._machine_credential
assert result == ssh_cred
assert result.credential_type.kind == 'ssh'
@pytest.mark.django_db
def test_run_job_machine_credential_none(job_template_with_credentials):
"""Test _machine_credential returns None when no ssh credential exists."""
task = jobs.RunJob()
# Create only vault credential
vault_type = CredentialType.defaults['vault']()
vault_type.save()
vault_cred = Credential.objects.create(credential_type=vault_type, name='vault-cred')
job = job_template_with_credentials(vault_cred)
task.instance = job
# Set cached credentials
task._credentials = [vault_cred]
# Get machine credential
result = task._machine_credential
assert result is None
@pytest.mark.django_db
def test_run_job_vault_credentials(job_template_with_credentials):
"""Test _vault_credentials returns all vault credentials from cache."""
task = jobs.RunJob()
# Create credentials
vault_type = CredentialType.defaults['vault']()
vault_type.save()
ssh_type = CredentialType.defaults['ssh']()
ssh_type.save()
vault_cred1 = Credential.objects.create(credential_type=vault_type, name='vault-1')
vault_cred2 = Credential.objects.create(credential_type=vault_type, name='vault-2')
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
job = job_template_with_credentials(vault_cred1, ssh_cred, vault_cred2)
task.instance = job
# Set cached credentials
task._credentials = [vault_cred1, ssh_cred, vault_cred2]
# Get vault credentials
result = task._vault_credentials
assert len(result) == 2
assert vault_cred1 in result
assert vault_cred2 in result
assert ssh_cred not in result
@pytest.mark.django_db
def test_run_job_network_credentials(job_template_with_credentials):
"""Test _network_credentials returns all network credentials from cache."""
task = jobs.RunJob()
# Create credentials
net_type = CredentialType.defaults['net']()
net_type.save()
ssh_type = CredentialType.defaults['ssh']()
ssh_type.save()
net_cred = Credential.objects.create(credential_type=net_type, name='net-cred')
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
job = job_template_with_credentials(net_cred, ssh_cred)
task.instance = job
# Set cached credentials
task._credentials = [net_cred, ssh_cred]
# Get network credentials
result = task._network_credentials
assert len(result) == 1
assert result[0] == net_cred
@pytest.mark.django_db
def test_run_job_cloud_credentials(job_template_with_credentials):
"""Test _cloud_credentials returns all cloud credentials from cache."""
task = jobs.RunJob()
# Create credentials
aws_type = CredentialType.defaults['aws']()
aws_type.save()
ssh_type = CredentialType.defaults['ssh']()
ssh_type.save()
aws_cred = Credential.objects.create(credential_type=aws_type, name='aws-cred')
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
job = job_template_with_credentials(aws_cred, ssh_cred)
task.instance = job
# Set cached credentials
task._credentials = [aws_cred, ssh_cred]
# Get cloud credentials
result = task._cloud_credentials
assert len(result) == 1
assert result[0] == aws_cred
@pytest.mark.django_db
@override_settings(RESOURCE_SERVER={'URL': 'https://gateway.example.com', 'SECRET_KEY': 'test-secret-key', 'VALIDATE_HTTPS': False})
def test_populate_workload_identity_tokens_with_flag_enabled(job_template_with_credentials, mocker):
"""Test populate_workload_identity_tokens sets context when flag is enabled."""
with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
task = jobs.RunJob()
# Create credential types
ssh_type = CredentialType.defaults['ssh']()
ssh_type.save()
# Create a workload identity credential type
hashivault_type = CredentialType(
name='HashiCorp Vault Secret Lookup (OIDC)',
kind='cloud',
managed=False,
inputs={
'fields': [
{'id': 'url', 'type': 'string', 'label': 'Server URL'},
{'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
]
},
)
hashivault_type.save()
# Create credentials
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
source_cred = Credential.objects.create(credential_type=hashivault_type, name='vault-source', inputs={'url': 'https://vault.example.com'})
target_cred = Credential.objects.create(credential_type=ssh_type, name='target-cred', inputs={'username': 'testuser'})
# Create input source linking source credential to target credential
input_source = CredentialInputSource.objects.create(
target_credential=target_cred, source_credential=source_cred, input_field_name='password', metadata={'path': 'secret/data/password'}
)
# Create a job using fixture
job = job_template_with_credentials(target_cred, ssh_cred)
task.instance = job
# Override cached_property so the loop uses these exact Python objects
task._credentials = [target_cred, ssh_cred]
# Mock only the HTTP response from the Gateway workload identity endpoint
mock_response = mocker.Mock(status_code=200)
mock_response.json.return_value = {'jwt': 'eyJ.test.jwt'}
mock_request = mocker.patch('requests.request', return_value=mock_response, autospec=True)
task.populate_workload_identity_tokens()
# Verify the HTTP call was made to the correct endpoint
mock_request.assert_called_once()
call_kwargs = mock_request.call_args.kwargs
assert call_kwargs['method'] == 'POST'
assert '/api/gateway/v1/workload_identity_tokens' in call_kwargs['url']
# Verify context was set on the credential, keyed by input source PK
assert input_source.pk in target_cred.context
assert target_cred.context[input_source.pk]['workload_identity_token'] == 'eyJ.test.jwt'
@pytest.mark.django_db
@override_settings(RESOURCE_SERVER={'URL': 'https://gateway.example.com', 'SECRET_KEY': 'test-secret-key', 'VALIDATE_HTTPS': False})
def test_populate_workload_identity_tokens_passes_workload_ttl_from_job_timeout(job_template_with_credentials, mocker):
"""Test populate_workload_identity_tokens passes workload_ttl_seconds from get_instance_timeout to the client."""
with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
task = jobs.RunJob()
ssh_type = CredentialType.defaults['ssh']()
ssh_type.save()
hashivault_type = CredentialType(
name='HashiCorp Vault Secret Lookup (OIDC)',
kind='cloud',
managed=False,
inputs={
'fields': [
{'id': 'url', 'type': 'string', 'label': 'Server URL'},
{'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
]
},
)
hashivault_type.save()
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
source_cred = Credential.objects.create(credential_type=hashivault_type, name='vault-source', inputs={'url': 'https://vault.example.com'})
target_cred = Credential.objects.create(credential_type=ssh_type, name='target-cred', inputs={'username': 'testuser'})
CredentialInputSource.objects.create(
target_credential=target_cred, source_credential=source_cred, input_field_name='password', metadata={'path': 'secret/data/password'}
)
job = job_template_with_credentials(target_cred, ssh_cred)
job.timeout = 3600
job.save()
task.instance = job
task._credentials = [target_cred, ssh_cred]
mock_response = mocker.Mock(status_code=200)
mock_response.json.return_value = {'jwt': 'eyJ.test.jwt'}
mock_request = mocker.patch('requests.request', return_value=mock_response, autospec=True)
task.populate_workload_identity_tokens()
call_kwargs = mock_request.call_args.kwargs
assert call_kwargs['method'] == 'POST'
json_body = call_kwargs.get('json', {})
assert json_body.get('workload_ttl_seconds') == 3600
@pytest.mark.django_db
def test_populate_workload_identity_tokens_with_flag_disabled(job_template_with_credentials):
"""Test populate_workload_identity_tokens sets error status when flag is disabled."""
with feature_flag_disabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
task = jobs.RunJob()
# Create credential types
ssh_type = CredentialType.defaults['ssh']()
ssh_type.save()
# Create a workload identity credential type
hashivault_type = CredentialType(
name='HashiCorp Vault Secret Lookup (OIDC)',
kind='cloud',
managed=False,
inputs={
'fields': [
{'id': 'url', 'type': 'string', 'label': 'Server URL'},
{'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
]
},
)
hashivault_type.save()
# Create credentials
source_cred = Credential.objects.create(credential_type=hashivault_type, name='vault-source')
target_cred = Credential.objects.create(credential_type=ssh_type, name='target-cred', inputs={'username': 'testuser'})
# Create input source linking source credential to target credential
# Note: Creates the relationship that will trigger the feature flag check
CredentialInputSource.objects.create(
target_credential=target_cred, source_credential=source_cred, input_field_name='password', metadata={'path': 'secret/data/password'}
)
# Create a job using fixture
job = job_template_with_credentials(target_cred)
task.instance = job
# Set cached credentials
task._credentials = [target_cred]
task.populate_workload_identity_tokens()
# Verify job status was set to error
job.refresh_from_db()
assert job.status == 'error'
assert 'FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED' in job.job_explanation
assert 'vault-source' in job.job_explanation
@pytest.mark.django_db
@override_settings(RESOURCE_SERVER={'URL': 'https://gateway.example.com', 'SECRET_KEY': 'test-secret-key', 'VALIDATE_HTTPS': False})
def test_populate_workload_identity_tokens_multiple_input_sources_per_credential(job_template_with_credentials, mocker):
"""Test that a single credential with two input sources from different workload identity
credential types gets a separate JWT token for each input source, keyed by input source PK."""
with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
task = jobs.RunJob()
# Create credential types
ssh_type = CredentialType.defaults['ssh']()
ssh_type.save()
# Create two different workload identity credential types
hashivault_kv_type = CredentialType(
name='HashiCorp Vault Secret Lookup (OIDC)',
kind='cloud',
managed=False,
inputs={
'fields': [
{'id': 'url', 'type': 'string', 'label': 'Server URL'},
{'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
]
},
)
hashivault_kv_type.save()
hashivault_ssh_type = CredentialType(
name='HashiCorp Vault Signed SSH (OIDC)',
kind='cloud',
managed=False,
inputs={
'fields': [
{'id': 'url', 'type': 'string', 'label': 'Server URL'},
{'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
]
},
)
hashivault_ssh_type.save()
# Create source credentials with different audiences
source_cred_kv = Credential.objects.create(credential_type=hashivault_kv_type, name='vault-kv-source', inputs={'url': 'https://vault-kv.example.com'})
source_cred_ssh = Credential.objects.create(
credential_type=hashivault_ssh_type, name='vault-ssh-source', inputs={'url': 'https://vault-ssh.example.com'}
)
# Create target credential that uses both sources for different fields
target_cred = Credential.objects.create(credential_type=ssh_type, name='target-cred', inputs={'username': 'testuser'})
# Create two input sources on the same target credential, each for a different field
input_source_password = CredentialInputSource.objects.create(
target_credential=target_cred, source_credential=source_cred_kv, input_field_name='password', metadata={'path': 'secret/data/password'}
)
input_source_ssh_key = CredentialInputSource.objects.create(
target_credential=target_cred, source_credential=source_cred_ssh, input_field_name='ssh_key_data', metadata={'path': 'secret/data/ssh_key'}
)
# Create a job using fixture
job = job_template_with_credentials(target_cred)
task.instance = job
# Override cached_property so the loop uses this exact Python object
task._credentials = [target_cred]
# Mock HTTP responses - return different JWTs for each call
response_kv = mocker.Mock(status_code=200)
response_kv.json.return_value = {'jwt': 'eyJ.kv.jwt'}
response_ssh = mocker.Mock(status_code=200)
response_ssh.json.return_value = {'jwt': 'eyJ.ssh.jwt'}
mock_request = mocker.patch('requests.request', side_effect=[response_kv, response_ssh], autospec=True)
task.populate_workload_identity_tokens()
# Verify two separate HTTP calls were made (one per input source)
assert mock_request.call_count == 2
# Verify each call used the correct audience from its source credential
audiences_requested = {call.kwargs.get('json', {}).get('audience', '') for call in mock_request.call_args_list}
assert 'https://vault-kv.example.com' in audiences_requested
assert 'https://vault-ssh.example.com' in audiences_requested
# Verify context on the target credential has both tokens, keyed by input source PK
assert input_source_password.pk in target_cred.context
assert input_source_ssh_key.pk in target_cred.context
assert target_cred.context[input_source_password.pk]['workload_identity_token'] == 'eyJ.kv.jwt'
assert target_cred.context[input_source_ssh_key.pk]['workload_identity_token'] == 'eyJ.ssh.jwt'
@pytest.mark.django_db
def test_populate_workload_identity_tokens_without_workload_identity_credentials(job_template_with_credentials, mocker):
"""Test populate_workload_identity_tokens does nothing when no workload identity credentials."""
with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
task = jobs.RunJob()
# Create only standard credentials (no workload identity)
ssh_type = CredentialType.defaults['ssh']()
ssh_type.save()
vault_type = CredentialType.defaults['vault']()
vault_type.save()
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
vault_cred = Credential.objects.create(credential_type=vault_type, name='vault-cred')
# Create a job using fixture
job = job_template_with_credentials(ssh_cred, vault_cred)
task.instance = job
# Set cached credentials
task._credentials = [ssh_cred, vault_cred]
mocker.patch('awx.main.tasks.jobs.populate_claims_for_workload', return_value={'job_id': 123}, autospec=True)
task.populate_workload_identity_tokens()
# Verify no context was set
assert not hasattr(ssh_cred, '_context') or ssh_cred.context == {}
assert not hasattr(vault_cred, '_context') or vault_cred.context == {}

View File

@@ -18,14 +18,13 @@ from awx.main.tests.functional.conftest import * # noqa
from awx.main.tests.conftest import load_all_credentials # noqa: F401; pylint: disable=unused-import
from awx.main.tests import data
from awx.main.models import Project, JobTemplate, Organization, Inventory, WorkflowJob, UnifiedJob
from awx.main.models import Project, JobTemplate, Organization, Inventory
from awx.main.tasks.system import clear_setting_cache
logger = logging.getLogger(__name__)
PROJ_DATA = os.path.join(os.path.dirname(data.__file__), 'projects')
COLL_DATA = os.path.join(os.path.dirname(data.__file__), 'collections')
def _copy_folders(source_path, dest_path, clear=False):
@@ -57,7 +56,6 @@ def live_tmp_folder():
shutil.rmtree(path)
os.mkdir(path)
_copy_folders(PROJ_DATA, path)
_copy_folders(COLL_DATA, path)
for dirname in os.listdir(path):
source_dir = os.path.join(path, dirname)
subprocess.run(GIT_COMMANDS, cwd=source_dir, shell=True)
@@ -102,21 +100,6 @@ def wait_for_events(uj, timeout=2):
def unified_job_stdout(uj):
if type(uj) is UnifiedJob:
uj = uj.get_real_instance()
if isinstance(uj, WorkflowJob):
outputs = []
for node in uj.workflow_job_nodes.all().select_related('job').order_by('id'):
if node.job is None:
continue
outputs.append(
'workflow node {node_id} job {job_id} output:\n{output}'.format(
node_id=node.id,
job_id=node.job.id,
output=unified_job_stdout(node.job),
)
)
return '\n'.join(outputs)
wait_for_events(uj)
return '\n'.join([event.stdout for event in uj.get_event_queryset().order_by('created')])

View File

@@ -1,351 +0,0 @@
"""Tests for concurrent fact caching with --limit.
Reproduces bugs where concurrent jobs targeting different hosts via --limit
incorrectly modify (clear or revert) facts for hosts outside their limit.
Customer report: concurrent jobs on the same job template with different limits
cause facts set by an earlier-finishing job to be rolled back when the
later-finishing job completes.
See: https://github.com/jritter/concurrent-aap-fact-caching
Generated by Claude Opus 4.6 (claude-opus-4-6).
"""
import logging
import pytest
from django.utils.timezone import now
from awx.api.versioning import reverse
from awx.main.models import Inventory, JobTemplate
from awx.main.tests.live.tests.conftest import wait_for_job, wait_to_leave_status
logger = logging.getLogger(__name__)
@pytest.fixture
def concurrent_facts_inventory(default_org):
"""Inventory with two hosts for concurrent fact cache testing."""
inv_name = 'test_concurrent_fact_cache'
Inventory.objects.filter(organization=default_org, name=inv_name).delete()
inv = Inventory.objects.create(organization=default_org, name=inv_name)
inv.hosts.create(name='cc_host_0')
inv.hosts.create(name='cc_host_1')
return inv
@pytest.fixture
def concurrent_facts_jt(concurrent_facts_inventory, live_tmp_folder, post, admin, project_factory):
"""Job template configured for concurrent fact-cached runs."""
proj = project_factory(scm_url=f'file://{live_tmp_folder}/facts')
if proj.current_job:
wait_for_job(proj.current_job)
assert 'gather_slow.yml' in proj.playbooks, f'gather_slow.yml not in {proj.playbooks}'
jt_name = 'test_concurrent_fact_cache JT'
existing_jt = JobTemplate.objects.filter(name=jt_name).first()
if existing_jt:
existing_jt.delete()
result = post(
reverse('api:job_template_list'),
{
'name': jt_name,
'project': proj.id,
'playbook': 'gather_slow.yml',
'inventory': concurrent_facts_inventory.id,
'use_fact_cache': True,
'allow_simultaneous': True,
},
admin,
expect=201,
)
return JobTemplate.objects.get(id=result.data['id'])
def test_concurrent_limit_does_not_clear_facts(concurrent_facts_inventory, concurrent_facts_jt):
"""Concurrent jobs with different --limit must not clear each other's facts.
Generated by Claude Opus 4.6 (claude-opus-4-6).
Scenario:
- Inventory has cc_host_0 and cc_host_1, neither has prior facts
- Job A runs gather_slow.yml with limit=cc_host_0
- While Job A is still running (sleeping), Job B launches with limit=cc_host_1
- Both jobs set cacheable facts, but only for their respective limited host
- After both complete, BOTH hosts should have populated facts
The bug: get_hosts_for_fact_cache() returns ALL inventory hosts regardless
of --limit. start_fact_cache records them all in hosts_cached but writes
no fact files (no prior facts). When the later-finishing job runs
finish_fact_cache, it sees a missing fact file for the other job's host
and clears that host's facts.
"""
inv = concurrent_facts_inventory
jt = concurrent_facts_jt
# Launch Job A targeting cc_host_0
job_a = jt.create_unified_job()
job_a.limit = 'cc_host_0'
job_a.save(update_fields=['limit'])
job_a.signal_start()
# Wait for Job A to reach running (it will sleep inside the playbook)
wait_to_leave_status(job_a, 'pending')
wait_to_leave_status(job_a, 'waiting')
logger.info(f'Job A (id={job_a.id}) is now running with limit=cc_host_0')
# Launch Job B targeting cc_host_1 while Job A is still running
job_b = jt.create_unified_job()
job_b.limit = 'cc_host_1'
job_b.save(update_fields=['limit'])
job_b.signal_start()
# Verify that Job A is still running when Job B starts,
# otherwise the overlap that triggers the bug did not happen.
wait_to_leave_status(job_b, 'pending')
wait_to_leave_status(job_b, 'waiting')
job_a.refresh_from_db()
if job_a.status != 'running':
pytest.skip('Job A finished before Job B started running; overlap did not occur')
logger.info(f'Job B (id={job_b.id}) is now running with limit=cc_host_1 (concurrent with Job A)')
# Wait for both to complete
wait_for_job(job_a)
wait_for_job(job_b)
# Verify facts survived concurrent execution
host_0 = inv.hosts.get(name='cc_host_0')
host_1 = inv.hosts.get(name='cc_host_1')
# sanity
job_a.refresh_from_db()
job_b.refresh_from_db()
assert job_a.limit == "cc_host_0"
assert job_b.limit == "cc_host_1"
discovered_foos = [host_0.ansible_facts.get('foo'), host_1.ansible_facts.get('foo')]
assert discovered_foos == ['bar'] * 2, f'Unexpected facts on cc_host_0 or _1: {discovered_foos} after job a,b {job_a.id}, {job_b.id}'
def test_concurrent_limit_does_not_revert_facts(live_tmp_folder, run_job_from_playbook, concurrent_facts_inventory):
"""Concurrent jobs must not revert facts that a prior concurrent job just set.
Generated by Claude Opus 4.6 (claude-opus-4-6).
Scenario:
- First, populate both hosts with initial facts (foo=bar) via a
non-concurrent gather run
- Then run two concurrent jobs with different limits, each setting
a new value (foo=bar_v2 via extra_vars)
- After both complete, BOTH hosts should have foo=bar_v2
The bug: start_fact_cache writes the OLD facts (foo=bar) into each job's
artifact dir for ALL hosts. If ansible's cache plugin rewrites a non-limited
host's fact file with the stale content (updating the mtime), finish_fact_cache
treats it as a legitimate update and overwrites the DB with old values.
"""
# --- Seed both hosts with initial facts via a non-concurrent run ---
inv = concurrent_facts_inventory
scm_url = f'file://{live_tmp_folder}/facts'
res = run_job_from_playbook(
'seed_facts_for_revert_test',
'gather_slow.yml',
scm_url=scm_url,
jt_params={'use_fact_cache': True, 'allow_simultaneous': True, 'inventory': inv.id},
)
for host in inv.hosts.all():
assert host.ansible_facts.get('foo') == 'bar', f'Seed run failed to set facts on {host.name}: {host.ansible_facts}'
job = res['job']
wait_for_job(job)
# sanity, jobs should be set up to both have facts with just bar
host_0 = inv.hosts.get(name='cc_host_0')
host_1 = inv.hosts.get(name='cc_host_1')
discovered_foos = [host_0.ansible_facts.get('foo'), host_1.ansible_facts.get('foo')]
assert discovered_foos == ['bar'] * 2, f'Facts did not get expected initial values: {discovered_foos}'
jt = job.job_template
assert jt.allow_simultaneous is True
assert jt.use_fact_cache is True
# Sanity assertion, sometimes this would give problems from the Django rel cache
assert jt.project
# --- Run two concurrent jobs that write a new value ---
# Update the JT to pass extra_vars that change the fact value
jt.extra_vars = '{"extra_value": "_v2"}'
jt.save(update_fields=['extra_vars'])
job_a = jt.create_unified_job()
job_a.limit = 'cc_host_0'
job_a.save(update_fields=['limit'])
job_a.signal_start()
wait_to_leave_status(job_a, 'pending')
wait_to_leave_status(job_a, 'waiting')
job_b = jt.create_unified_job()
job_b.limit = 'cc_host_1'
job_b.save(update_fields=['limit'])
job_b.signal_start()
wait_to_leave_status(job_b, 'pending')
wait_to_leave_status(job_b, 'waiting')
job_a.refresh_from_db()
if job_a.status != 'running':
pytest.skip('Job A finished before Job B started running; overlap did not occur')
wait_for_job(job_a)
wait_for_job(job_b)
host_0 = inv.hosts.get(name='cc_host_0')
host_1 = inv.hosts.get(name='cc_host_1')
# Both hosts should have the UPDATED value, not the old seed value
discovered_foos = [host_0.ansible_facts.get('foo'), host_1.ansible_facts.get('foo')]
assert discovered_foos == ['bar_v2'] * 2, f'Facts were reverted to stale values by concurrent job cc_host_0 or cc_host_1: {discovered_foos}'
def test_fact_cache_scoped_to_inventory(live_tmp_folder, default_org, run_job_from_playbook):
"""finish_fact_cache must not modify hosts in other inventories.
Generated by Claude Opus 4.6 (claude-opus-4-6).
Bug: finish_fact_cache queries Host.objects.filter(name__in=host_names)
without an inventory_id filter, so hosts with the same name in different
inventories get their facts cross-contaminated.
"""
shared_name = 'scope_shared_host'
# Prepare for test by deleting junk from last run
for inv_name in ('test_fact_scope_inv1', 'test_fact_scope_inv2'):
inv = Inventory.objects.filter(name=inv_name).first()
if inv:
inv.delete()
inv1 = Inventory.objects.create(organization=default_org, name='test_fact_scope_inv1')
inv1.hosts.create(name=shared_name)
inv2 = Inventory.objects.create(organization=default_org, name='test_fact_scope_inv2')
host2 = inv2.hosts.create(name=shared_name)
# Give inv2's host distinct facts that should not be touched
original_facts = {'source': 'inventory_2', 'untouched': True}
host2.ansible_facts = original_facts
host2.ansible_facts_modified = now()
host2.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
# Run a fact-gathering job against inv1 only
run_job_from_playbook(
'test_fact_scope',
'gather.yml',
scm_url=f'file://{live_tmp_folder}/facts',
jt_params={'use_fact_cache': True, 'inventory': inv1.id},
)
# inv1's host should have facts
host1 = inv1.hosts.get(name=shared_name)
assert host1.ansible_facts, f'inv1 host should have facts after gather: {host1.ansible_facts}'
# inv2's host must NOT have been touched
host2.refresh_from_db()
assert host2.ansible_facts == original_facts, (
f'Host in a different inventory was modified by a fact cache operation '
f'on another inventory sharing the same hostname. '
f'Expected {original_facts!r}, got {host2.ansible_facts!r}'
)
def test_constructed_inventory_facts_saved_to_source_host(live_tmp_folder, default_org, run_job_from_playbook):
"""Facts from a constructed inventory job must be saved to the source host.
Generated by Claude Opus 4.6 (claude-opus-4-6).
Constructed inventories contain hosts that are references (via instance_id)
to 'real' hosts in input inventories. start_fact_cache correctly resolves
source hosts via get_hosts_for_fact_cache(), but finish_fact_cache must also
write facts back to the source hosts, not the constructed inventory's copies.
Scenario:
- Two input inventories each have a host named 'ci_shared_host'
- A constructed inventory uses both as inputs
- The inventory sync picks one source host (via instance_id) for the
constructed host — which one depends on input processing order
- Both source hosts start with distinct pre-existing facts
- A fact-gathering job runs against the constructed inventory
- After completion, the targeted source host should have the job's facts
- The OTHER source host must retain its original facts untouched
- The constructed host itself must NOT have facts stored on it
(constructed hosts are transient — recreated on each inventory sync)
"""
shared_name = 'ci_shared_host'
# Cleanup from prior runs
for inv_name in ('test_ci_facts_input1', 'test_ci_facts_input2', 'test_ci_facts_constructed'):
Inventory.objects.filter(name=inv_name).delete()
# --- Create two input inventories, each with an identically-named host ---
inv1 = Inventory.objects.create(organization=default_org, name='test_ci_facts_input1')
source_host1 = inv1.hosts.create(name=shared_name)
inv2 = Inventory.objects.create(organization=default_org, name='test_ci_facts_input2')
source_host2 = inv2.hosts.create(name=shared_name)
# Give both hosts distinct pre-existing facts so we can detect cross-contamination
host1_original_facts = {'source': 'inventory_1'}
source_host1.ansible_facts = host1_original_facts
source_host1.ansible_facts_modified = now()
source_host1.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
host2_original_facts = {'source': 'inventory_2'}
source_host2.ansible_facts = host2_original_facts
source_host2.ansible_facts_modified = now()
source_host2.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
source_hosts_by_id = {source_host1.id: source_host1, source_host2.id: source_host2}
original_facts_by_id = {source_host1.id: host1_original_facts, source_host2.id: host2_original_facts}
# --- Create constructed inventory (sync will create hosts from inputs) ---
constructed_inv = Inventory.objects.create(
organization=default_org,
name='test_ci_facts_constructed',
kind='constructed',
)
constructed_inv.input_inventories.add(inv1)
constructed_inv.input_inventories.add(inv2)
# --- Run a fact-gathering job against the constructed inventory ---
# The job launch triggers an inventory sync which creates the constructed
# host with instance_id pointing to one of the source hosts.
scm_url = f'file://{live_tmp_folder}/facts'
run_job_from_playbook(
'test_ci_facts',
'gather.yml',
scm_url=scm_url,
jt_params={'use_fact_cache': True, 'inventory': constructed_inv.id},
)
# --- Determine which source host the constructed host points to ---
constructed_host = constructed_inv.hosts.get(name=shared_name)
target_id = int(constructed_host.instance_id)
other_id = (set(source_hosts_by_id.keys()) - {target_id}).pop()
target_host = source_hosts_by_id[target_id]
other_host = source_hosts_by_id[other_id]
target_host.refresh_from_db()
other_host.refresh_from_db()
constructed_host.refresh_from_db()
actual = [target_host.ansible_facts.get('foo'), other_host.ansible_facts, constructed_host.ansible_facts]
expected = ['bar', original_facts_by_id[other_id], {}]
assert actual == expected, (
f'Constructed inventory fact cache wrote to wrong host(s). '
f'target source host (id={target_id}) foo={actual[0]!r}, '
f'other source host (id={other_id}) facts={actual[1]!r}, '
f'constructed host facts={actual[2]!r}; expected {expected!r}'
)

View File

@@ -1,347 +0,0 @@
"""
Integration tests for external query file functionality (AAP-58470).
Tests verify the end-to-end external query file workflow for indirect node
counting using real AWX job execution. A fixture-created vendor collection
at /var/lib/awx/vendor_collections/ provides external query files, simulating
what the build-time (AAP-58426) and deployment (AAP-58557) integrations will
provide once available.
Test data:
- Collection 'demo.external' at various versions (no embedded query)
- External query files in mock redhat.indirect_accounting collection
"""
import os
import shutil
import time
import yaml
import pytest
from flags.state import enable_flag, disable_flag, flag_enabled
from awx.main.tests.live.tests.conftest import wait_for_events, unified_job_stdout
from awx.main.tasks.host_indirect import save_indirect_host_entries
from awx.main.models.indirect_managed_node_audit import IndirectManagedNodeAudit
from awx.main.models.event_query import EventQuery
from awx.main.models import Job
# --- Constants ---
EXTERNAL_QUERY_JQ = '{name: .name, canonical_facts: {host_name: .direct_host_name}, facts: {device_type: .device_type}}'
EXTERNAL_QUERY_CONTENT = yaml.dump(
{'demo.external.example': {'query': EXTERNAL_QUERY_JQ}},
default_flow_style=False,
)
# For precedence test: different jq (no device_type in facts) so we can detect which query was used
EXTERNAL_QUERY_FOR_DEMO_QUERY_JQ = '{name: .name, canonical_facts: {host_name: .direct_host_name}, facts: {}}'
EXTERNAL_QUERY_FOR_DEMO_QUERY_CONTENT = yaml.dump(
{'demo.query.example': {'query': EXTERNAL_QUERY_FOR_DEMO_QUERY_JQ}},
default_flow_style=False,
)
VENDOR_COLLECTIONS_BASE = '/var/lib/awx/vendor_collections'
# --- Fixtures ---
@pytest.fixture
def enable_indirect_host_counting():
"""Enable FEATURE_INDIRECT_NODE_COUNTING_ENABLED flag for the test.
Only creates a FlagState DB record if the flag isn't already enabled
(e.g. via development_defaults.py), to avoid UniqueViolation errors
and to avoid leaking state to other tests.
"""
flag_name = "FEATURE_INDIRECT_NODE_COUNTING_ENABLED"
was_enabled = flag_enabled(flag_name)
if not was_enabled:
enable_flag(flag_name)
yield
if not was_enabled:
disable_flag(flag_name)
@pytest.fixture
def vendor_collections_dir():
"""Set up mock redhat.indirect_accounting collection at /var/lib/awx/vendor_collections/.
Creates the collection structure with external query files:
- demo.external.1.0.0.yml (exact match for v1.0.0)
- demo.external.1.1.0.yml (fallback target for v1.5.0)
- demo.query.0.0.1.yml (for precedence test with embedded-query collection)
"""
base = os.path.join(VENDOR_COLLECTIONS_BASE, 'ansible_collections', 'redhat', 'indirect_accounting')
queries_path = os.path.join(base, 'extensions', 'audit', 'external_queries')
meta_path = os.path.join(base, 'meta')
os.makedirs(queries_path, exist_ok=True)
os.makedirs(meta_path, exist_ok=True)
# galaxy.yml for valid collection structure
with open(os.path.join(base, 'galaxy.yml'), 'w') as f:
yaml.dump(
{
'namespace': 'redhat',
'name': 'indirect_accounting',
'version': '1.0.0',
'description': 'Test fixture for external query integration tests',
'authors': ['AWX Tests'],
'dependencies': {},
},
f,
)
# meta/runtime.yml
with open(os.path.join(meta_path, 'runtime.yml'), 'w') as f:
yaml.dump({'requires_ansible': '>=2.15.0'}, f)
# External query files for demo.external collection
for version in ('1.0.0', '1.1.0'):
with open(os.path.join(queries_path, f'demo.external.{version}.yml'), 'w') as f:
f.write(EXTERNAL_QUERY_CONTENT)
# External query file for demo.query collection (precedence test)
with open(os.path.join(queries_path, 'demo.query.0.0.1.yml'), 'w') as f:
f.write(EXTERNAL_QUERY_FOR_DEMO_QUERY_CONTENT)
yield base
# Cleanup: only remove the collection we created, not the entire vendor root
shutil.rmtree(base, ignore_errors=True)
@pytest.fixture(autouse=True)
def cleanup_test_data():
"""Clean up EventQuery and IndirectManagedNodeAudit records after each test."""
yield
EventQuery.objects.filter(fqcn='demo.external').delete()
EventQuery.objects.filter(fqcn='demo.query').delete()
IndirectManagedNodeAudit.objects.filter(job__name__icontains='external_query').delete()
# --- Helpers ---
def run_external_query_job(run_job_from_playbook, live_tmp_folder, test_name, project_dir, jt_params=None):
"""Run a job and return the Job object after waiting for indirect host processing."""
scm_url = f'file://{live_tmp_folder}/{project_dir}'
run_job_from_playbook(test_name, 'run_task.yml', scm_url=scm_url, jt_params=jt_params)
job = Job.objects.filter(name__icontains=test_name).order_by('-created').first()
assert job is not None, f'Job not found for test {test_name}'
wait_for_events(job)
return job
def wait_for_indirect_processing(job, expect_records=True, timeout=5):
"""Wait for indirect host processing to complete.
Follows the same pattern as test_indirect_host_counting.py:53-72.
"""
# Ensure indirect host processing runs (wait_for_events already called by caller)
job.refresh_from_db()
if job.event_queries_processed is False:
save_indirect_host_entries.delay(job.id, wait_for_events=False)
if expect_records:
# Poll for audit records to appear
for _ in range(20):
if IndirectManagedNodeAudit.objects.filter(job=job).exists():
break
time.sleep(0.25)
else:
raise RuntimeError(f'No IndirectManagedNodeAudit records populated for job_id={job.id}')
else:
# For negative tests, wait a reasonable time to confirm no records appear
time.sleep(timeout)
job.refresh_from_db()
# --- AC8.1: External query populates IndirectManagedNodeAudit correctly ---
def test_external_query_populates_audit_table(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
"""AC8.1: Job using demo.external.example with external query file populates
IndirectManagedNodeAudit table correctly.
Uses demo.external v1.0.0 with exact-match external query file demo.external.1.0.0.yml.
"""
job = run_external_query_job(
run_job_from_playbook,
live_tmp_folder,
'external_query_ac8_1',
'test_host_query_external_v1_0_0',
)
wait_for_indirect_processing(job, expect_records=True)
# Verify installed_collections captured demo.external
assert 'demo.external' in job.installed_collections
assert 'host_query' in job.installed_collections['demo.external']
# Verify IndirectManagedNodeAudit records
assert IndirectManagedNodeAudit.objects.filter(job=job).count() == 1
host_audit = IndirectManagedNodeAudit.objects.filter(job=job).first()
assert host_audit.canonical_facts == {'host_name': 'foo_host_default'}
assert host_audit.facts == {'device_type': 'Fake Host'}
assert host_audit.name == 'vm-foo'
assert host_audit.organization == job.organization
assert 'demo.external.example' in host_audit.events
# --- AC8.2: Precedence - embedded query takes precedence over external ---
def test_embedded_query_takes_precedence(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
"""AC8.2: When collection has both embedded and external query files,
the embedded query takes precedence.
Uses demo.query v0.0.1 which HAS an embedded query (extensions/audit/event_query.yml).
An external query (demo.query.0.0.1.yml) also exists but uses a different jq expression
(no device_type in facts). By checking the audit record's facts, we verify which query was used.
"""
# Run with demo.query collection (has embedded query)
job = run_external_query_job(
run_job_from_playbook,
live_tmp_folder,
'external_query_ac8_2',
'test_host_query',
)
wait_for_indirect_processing(job, expect_records=True)
# Verify the embedded query was used (includes device_type in facts)
host_audit = IndirectManagedNodeAudit.objects.filter(job=job).first()
assert host_audit.facts == {'device_type': 'Fake Host'}, (
'Expected embedded query output (with device_type). ' 'If facts is {}, the external query was incorrectly used instead.'
)
# --- AC8.3: Version fallback to compatible version ---
def test_fallback_to_compatible_version(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
"""AC8.3: Job using collection version with no exact query file falls back
correctly to compatible version.
Uses demo.external v1.5.0. No demo.external.1.5.0.yml exists, but
demo.external.1.1.0.yml is available (same major version, highest <= 1.5.0).
The fallback should find and use the 1.1.0 query.
"""
job = run_external_query_job(
run_job_from_playbook,
live_tmp_folder,
'external_query_ac8_3',
'test_host_query_external_v1_5_0',
)
wait_for_indirect_processing(job, expect_records=True)
# Verify installed_collections captured demo.external at v1.5.0
assert 'demo.external' in job.installed_collections
assert job.installed_collections['demo.external']['version'] == '1.5.0'
# Verify IndirectManagedNodeAudit records were created via fallback
assert IndirectManagedNodeAudit.objects.filter(job=job).count() == 1
host_audit = IndirectManagedNodeAudit.objects.filter(job=job).first()
assert host_audit.canonical_facts == {'host_name': 'foo_host_default'}
assert host_audit.facts == {'device_type': 'Fake Host'}
assert host_audit.name == 'vm-foo'
# --- AC8.4: Fallback queries don't overcount ---
def test_fallback_does_not_overcount(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
"""AC8.4: Fallback queries don't count MORE nodes than exact-version queries.
Runs two jobs:
1. Exact match scenario (demo.external v1.0.0 -> demo.external.1.0.0.yml)
2. Fallback scenario (demo.external v1.5.0 -> falls back to demo.external.1.1.0.yml)
Verifies that fallback record count <= exact record count.
"""
# Run exact-match job
exact_job = run_external_query_job(
run_job_from_playbook,
live_tmp_folder,
'external_query_ac8_4_exact',
'test_host_query_external_v1_0_0',
)
wait_for_indirect_processing(exact_job, expect_records=True)
exact_count = IndirectManagedNodeAudit.objects.filter(job=exact_job).count()
# Run fallback job
fallback_job = run_external_query_job(
run_job_from_playbook,
live_tmp_folder,
'external_query_ac8_4_fallback',
'test_host_query_external_v1_5_0',
)
wait_for_indirect_processing(fallback_job, expect_records=True)
fallback_count = IndirectManagedNodeAudit.objects.filter(job=fallback_job).count()
# Critical safety check: fallback must never count MORE than exact
assert fallback_count <= exact_count, (
f'Overcounting detected! Fallback produced {fallback_count} records ' f'but exact match produced only {exact_count} records.'
)
# Both use the same jq expression and same module, so counts should be equal
assert exact_count == fallback_count
# --- AC8.5: Warning logs contain correct version information ---
def test_fallback_log_contains_version_info(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
"""AC8.5: Warning logs contain correct version information when fallback is used.
Runs a job with verbosity=1 so callback plugin verbose output is captured.
Verifies the log contains the installed version (1.5.0), fallback version (1.1.0),
and collection FQCN (demo.external).
"""
job = run_external_query_job(
run_job_from_playbook,
live_tmp_folder,
'external_query_ac8_5',
'test_host_query_external_v1_5_0',
jt_params={'verbosity': 1},
)
wait_for_indirect_processing(job, expect_records=True)
# Get job stdout to check for fallback log message
stdout = unified_job_stdout(job)
# The callback plugin emits: "Using external query {version_used} for {fqcn} v{ver}."
assert '1.1.0' in stdout, f'Fallback version 1.1.0 not found in job stdout. stdout:\n{stdout}'
assert 'demo.external' in stdout, f'Collection FQCN demo.external not found in job stdout. stdout:\n{stdout}'
assert '1.5.0' in stdout, f'Installed version 1.5.0 not found in job stdout. stdout:\n{stdout}'
# --- AC8.6: No counting when no compatible fallback exists ---
def test_no_counting_without_compatible_fallback(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
"""AC8.6: No counting occurs when no compatible fallback exists.
Uses demo.external v3.0.0 with only v1.x external query files available.
Since major versions differ (3 vs 1), no fallback should occur and no
IndirectManagedNodeAudit records should be created.
"""
job = run_external_query_job(
run_job_from_playbook,
live_tmp_folder,
'external_query_ac8_6',
'test_host_query_external_v3_0_0',
)
wait_for_indirect_processing(job, expect_records=False)
# No audit records should exist for this job
assert IndirectManagedNodeAudit.objects.filter(job=job).count() == 0, (
'IndirectManagedNodeAudit records were created despite no compatible ' 'fallback existing for demo.external v3.0.0 (only v1.x queries available).'
)

View File

@@ -1,206 +0,0 @@
import json
import pytest
from awx.main.tests.live.tests.conftest import wait_for_job
from awx.main.models import JobTemplate, WorkflowJobTemplate, WorkflowJobTemplateNode
JT_NAMES = ('artifact-test-first', 'artifact-test-second', 'artifact-test-reader')
WFT_NAMES = ('artifact-test-outer-wf', 'artifact-test-inner-wf')
@pytest.mark.django_db(transaction=True)
def test_nested_workflow_set_stats_precedence(live_tmp_folder, demo_inv, project_factory, default_org):
"""Reproducer for set_stats artifacts from an outer workflow leaking into
an inner (child) workflow and overriding the inner workflow's own artifacts.
Outer WF: [job_first] --success--> [inner_wf]
Inner WF: [job_second] --success--> [job_reader]
job_first sets via set_stats:
var1: "outer-only" (only source, should propagate through)
var2: "should-be-overridden" (will be overridden by job_second)
job_second sets via set_stats:
var2: "from-inner" (should override outer's value)
var3: "inner-only" (only source, should be available)
job_reader runs debug.yml (no set_stats), we inspect its extra_vars:
var1 should be "outer-only" - outer artifacts propagate when uncontested
var2 should be "from-inner" - inner artifacts override outer (THE BUG)
var3 should be "inner-only" - inner-only artifacts propagate normally
"""
# Clean up resources from prior runs (delete individually for signals)
for name in WFT_NAMES:
for wft in WorkflowJobTemplate.objects.filter(name=name):
wft.delete()
for name in JT_NAMES:
for jt in JobTemplate.objects.filter(name=name):
jt.delete()
proj = project_factory(scm_url=f'file://{live_tmp_folder}/debug')
if proj.current_job:
wait_for_job(proj.current_job)
# job_first: sets var1 (outer-only) and var2 (to be overridden by inner)
jt_first = JobTemplate.objects.create(
name='artifact-test-first',
project=proj,
playbook='set_stats.yml',
inventory=demo_inv,
extra_vars=json.dumps({'stats_data': {'var1': 'outer-only', 'var2': 'should-be-overridden'}}),
)
# job_second: overrides var2, introduces var3
jt_second = JobTemplate.objects.create(
name='artifact-test-second',
project=proj,
playbook='set_stats.yml',
inventory=demo_inv,
extra_vars=json.dumps({'stats_data': {'var2': 'from-inner', 'var3': 'inner-only'}}),
)
# job_reader: just runs, we check what extra_vars it receives
jt_reader = JobTemplate.objects.create(
name='artifact-test-reader',
project=proj,
playbook='debug.yml',
inventory=demo_inv,
)
# Inner WFT: job_second -> job_reader
inner_wft = WorkflowJobTemplate.objects.create(name='artifact-test-inner-wf', organization=default_org)
inner_node_1 = WorkflowJobTemplateNode.objects.create(
workflow_job_template=inner_wft,
unified_job_template=jt_second,
identifier='second',
)
inner_node_2 = WorkflowJobTemplateNode.objects.create(
workflow_job_template=inner_wft,
unified_job_template=jt_reader,
identifier='reader',
)
inner_node_1.success_nodes.add(inner_node_2)
# Outer WFT: job_first -> inner_wf
outer_wft = WorkflowJobTemplate.objects.create(name='artifact-test-outer-wf', organization=default_org)
outer_node_1 = WorkflowJobTemplateNode.objects.create(
workflow_job_template=outer_wft,
unified_job_template=jt_first,
identifier='first',
)
outer_node_2 = WorkflowJobTemplateNode.objects.create(
workflow_job_template=outer_wft,
unified_job_template=inner_wft,
identifier='inner',
)
outer_node_1.success_nodes.add(outer_node_2)
# Launch and wait
outer_wfj = outer_wft.create_unified_job()
outer_wfj.signal_start()
wait_for_job(outer_wfj, running_timeout=120)
# Find the reader job inside the inner workflow
inner_wf_node = outer_wfj.workflow_job_nodes.get(identifier='inner')
inner_wfj = inner_wf_node.job
assert inner_wfj is not None, 'Inner workflow job was never created'
# Check that root node of inner WF (job_second) received outer artifacts
second_node = inner_wfj.workflow_job_nodes.get(identifier='second')
assert second_node.job is not None, 'Second job was never created'
second_extra_vars = json.loads(second_node.job.extra_vars)
assert second_extra_vars.get('var1') == 'outer-only', (
f'Root node var1: expected "outer-only" (outer artifact should be available to root node), '
f'got "{second_extra_vars.get("var1")}". '
f'Outer artifacts are not reaching root nodes of child workflows.'
)
reader_node = inner_wfj.workflow_job_nodes.get(identifier='reader')
assert reader_node.job is not None, 'Reader job was never created'
reader_extra_vars = json.loads(reader_node.job.extra_vars)
# var1: only set by outer job_first, no conflict — should propagate through
assert reader_extra_vars.get('var1') == 'outer-only', f'var1: expected "outer-only" (uncontested outer artifact), ' f'got "{reader_extra_vars.get("var1")}"'
# var2: set by outer as "should-be-overridden", then by inner as "from-inner"
# Inner workflow's own ancestor artifacts should take precedence
assert reader_extra_vars.get('var2') == 'from-inner', (
f'var2: expected "from-inner" (inner workflow artifact should override outer), '
f'got "{reader_extra_vars.get("var2")}". '
f'Outer workflow artifacts are leaking via wj_special_vars. '
f'reader node ancestor_artifacts={reader_node.ancestor_artifacts}'
)
# var3: only set by inner job_second — should propagate normally
assert reader_extra_vars.get('var3') == 'inner-only', f'var3: expected "inner-only" (inner-only artifact), ' f'got "{reader_extra_vars.get("var3")}"'
@pytest.mark.django_db(transaction=True)
def test_workflow_extra_vars_override_artifacts(live_tmp_folder, demo_inv, project_factory, default_org):
"""Workflow extra_vars should take precedence over set_stats artifacts
within a single (non-nested) workflow.
WF (extra_vars: my_var="from-wf-extra-vars"):
[job_setter] --success--> [job_reader]
job_setter sets my_var="from-set-stats" via set_stats
job_reader should see my_var="from-wf-extra-vars" because workflow
extra_vars are higher precedence than ancestor artifacts.
"""
wft_name = 'artifact-test-wf-extra-vars-precedence'
jt_names = ('artifact-test-setter', 'artifact-test-checker')
for wft in WorkflowJobTemplate.objects.filter(name=wft_name):
wft.delete()
for name in jt_names:
for jt in JobTemplate.objects.filter(name=name):
jt.delete()
proj = project_factory(scm_url=f'file://{live_tmp_folder}/debug')
if proj.current_job:
wait_for_job(proj.current_job)
jt_setter = JobTemplate.objects.create(
name='artifact-test-setter',
project=proj,
playbook='set_stats.yml',
inventory=demo_inv,
extra_vars=json.dumps({'stats_data': {'my_var': 'from-set-stats'}}),
)
jt_checker = JobTemplate.objects.create(
name='artifact-test-checker',
project=proj,
playbook='debug.yml',
inventory=demo_inv,
)
wft = WorkflowJobTemplate.objects.create(
name=wft_name,
organization=default_org,
extra_vars=json.dumps({'my_var': 'from-wf-extra-vars'}),
)
node_1 = WorkflowJobTemplateNode.objects.create(
workflow_job_template=wft,
unified_job_template=jt_setter,
identifier='setter',
)
node_2 = WorkflowJobTemplateNode.objects.create(
workflow_job_template=wft,
unified_job_template=jt_checker,
identifier='checker',
)
node_1.success_nodes.add(node_2)
wfj = wft.create_unified_job()
wfj.signal_start()
wait_for_job(wfj, running_timeout=120)
checker_node = wfj.workflow_job_nodes.get(identifier='checker')
assert checker_node.job is not None, 'Checker job was never created'
checker_extra_vars = json.loads(checker_node.job.extra_vars)
assert checker_extra_vars.get('my_var') == 'from-wf-extra-vars', (
f'Expected my_var="from-wf-extra-vars" (workflow extra_vars should override artifacts), '
f'got my_var="{checker_extra_vars.get("my_var")}". '
f'checker node ancestor_artifacts={checker_node.ancestor_artifacts}'
)

View File

@@ -1,271 +0,0 @@
# Copyright (c) 2026 Ansible, Inc.
# All Rights Reserved.
"""Tests for analytics ship() function with mTLS authentication."""
import os
import tempfile
from unittest import mock
from django.test.utils import override_settings
from awx.main.analytics.core import ship, _get_cert_upload_url
class TestGetCertUploadUrl:
"""Test _get_cert_upload_url() helper function."""
def test_adds_cert_subdomain(self):
"""Test that 'cert.' is added to hostname."""
url = 'https://analytics.example.com/api/ingress/v1/upload'
result = _get_cert_upload_url(url)
assert result == 'https://cert.analytics.example.com/api/ingress/v1/upload'
def test_preserves_existing_cert_subdomain(self):
"""Test that existing 'cert.' subdomain is preserved."""
url = 'https://cert.analytics.example.com/api/ingress/v1/upload'
result = _get_cert_upload_url(url)
assert result == 'https://cert.analytics.example.com/api/ingress/v1/upload'
class TestShipMTLS:
"""Test ship() function's mTLS authentication path."""
def setup_method(self):
"""Create a temporary tarball for testing."""
self.temp_file = tempfile.NamedTemporaryFile(mode='wb', suffix='.tar.gz', delete=False)
self.temp_file.write(b'test tarball content')
self.temp_file.close()
self.tarball_path = self.temp_file.name
def teardown_method(self):
"""Clean up temporary tarball."""
if os.path.exists(self.tarball_path):
os.unlink(self.tarball_path)
@override_settings(
AUTOMATION_ANALYTICS_URL='https://analytics.example.com/api/ingress/v1/upload',
INSIGHTS_AGENT_MIME='application/vnd.redhat.tower.analytics+tgz',
INSIGHTS_CERT_PATH='/etc/pki/tls/certs/ca-bundle.crt',
REDHAT_USERNAME='test_user',
REDHAT_PASSWORD='test_pass', # NOSONAR
AWX_TASK_ENV={},
)
@mock.patch('awx.main.analytics.core.get_awx_http_client_headers')
@mock.patch('awx.main.analytics.core._temp_cert_files')
@mock.patch('awx.main.analytics.core.get_or_generate_candlepin_certificate')
@mock.patch('awx.main.analytics.core.requests.Session')
def test_ship_with_mtls_success(self, mock_session_class, mock_get_cert, mock_temp_files, mock_headers):
"""Test successful upload with mTLS certificate authentication."""
# Mock headers to avoid database access
mock_headers.return_value = {'Content-Type': 'application/json'}
# Mock certificate retrieval
mock_get_cert.return_value = ('cert-pem-data', 'key-pem-data')
# Mock temp files context manager
mock_temp_files.return_value.__enter__.return_value = ('/tmp/cert.pem', '/tmp/key.pem')
mock_temp_files.return_value.__exit__.return_value = None
# Mock successful mTLS response
mock_response = mock.Mock()
mock_response.status_code = 200
mock_session = mock.Mock()
mock_session.headers = {}
mock_session.post.return_value = mock_response
mock_session_class.return_value = mock_session
result = ship(self.tarball_path)
assert result is True
mock_get_cert.assert_called_once()
mock_temp_files.assert_called_once_with('cert-pem-data', 'key-pem-data')
mock_session.post.assert_called_once()
# Verify cert URL is used (cert. subdomain added)
call_args = mock_session.post.call_args
assert call_args[0][0] == 'https://cert.analytics.example.com/api/ingress/v1/upload'
# Verify mTLS cert was used
call_kwargs = call_args[1]
assert call_kwargs['cert'] == ('/tmp/cert.pem', '/tmp/key.pem')
@override_settings(
AUTOMATION_ANALYTICS_URL='https://analytics.example.com/api/ingress/v1/upload',
INSIGHTS_AGENT_MIME='application/vnd.redhat.tower.analytics+tgz',
INSIGHTS_CERT_PATH='/etc/pki/tls/certs/ca-bundle.crt',
REDHAT_USERNAME='test_user',
REDHAT_PASSWORD='test_pass', # NOSONAR
AWX_TASK_ENV={},
)
@mock.patch('awx.main.analytics.core.get_awx_http_client_headers')
@mock.patch('awx.main.analytics.core.OIDCClient')
@mock.patch('awx.main.analytics.core._temp_cert_files')
@mock.patch('awx.main.analytics.core.get_or_generate_candlepin_certificate')
@mock.patch('awx.main.analytics.core.requests.Session')
def test_ship_mtls_fallback_to_oidc_on_cert_failure(self, mock_session_class, mock_get_cert, mock_temp_files, mock_oidc_client, mock_headers):
"""Test fallback to OIDC auth when mTLS cert authentication fails."""
# Mock headers to avoid database access
mock_headers.return_value = {'Content-Type': 'application/json'}
# Mock certificate retrieval
mock_get_cert.return_value = ('cert-pem-data', 'key-pem-data')
# Mock temp files context manager
mock_temp_files.return_value.__enter__.return_value = ('/tmp/cert.pem', '/tmp/key.pem')
mock_temp_files.return_value.__exit__.return_value = None
# Mock failed mTLS response (401 Unauthorized)
mock_mtls_response = mock.Mock()
mock_mtls_response.status_code = 401
mock_session = mock.Mock()
mock_session.headers = {}
mock_session.post.return_value = mock_mtls_response
mock_session_class.return_value = mock_session
# Mock successful OIDC response
mock_oidc_response = mock.Mock()
mock_oidc_response.status_code = 200
mock_oidc_instance = mock.Mock()
mock_oidc_instance.make_request.return_value = mock_oidc_response
mock_oidc_client.return_value = mock_oidc_instance
result = ship(self.tarball_path)
assert result is True
# Both mTLS and OIDC should be attempted
assert mock_session.post.call_count == 1
mock_oidc_instance.make_request.assert_called_once()
# Verify mTLS used cert URL
mtls_call_args = mock_session.post.call_args
assert mtls_call_args[0][0] == 'https://cert.analytics.example.com/api/ingress/v1/upload'
# Verify OIDC used original URL
oidc_call_args = mock_oidc_instance.make_request.call_args
assert oidc_call_args[0][1] == 'https://analytics.example.com/api/ingress/v1/upload'
@override_settings(
AUTOMATION_ANALYTICS_URL='https://analytics.example.com/api/ingress/v1/upload',
INSIGHTS_AGENT_MIME='application/vnd.redhat.tower.analytics+tgz',
INSIGHTS_CERT_PATH='/etc/pki/tls/certs/ca-bundle.crt',
REDHAT_USERNAME='test_user',
REDHAT_PASSWORD='test_pass', # NOSONAR
AWX_TASK_ENV={},
)
@mock.patch('awx.main.analytics.core.get_awx_http_client_headers')
@mock.patch('awx.main.analytics.core._temp_cert_files')
@mock.patch('awx.main.analytics.core.get_or_generate_candlepin_certificate')
@mock.patch('awx.main.analytics.core.OIDCClient')
@mock.patch('awx.main.analytics.core.requests.Session')
def test_ship_mtls_exception_fallback_to_oidc(self, mock_session_class, mock_oidc_client, mock_get_cert, mock_temp_files, mock_headers):
"""Test fallback to OIDC auth when mTLS raises an exception."""
# Mock headers to avoid database access
mock_headers.return_value = {'Content-Type': 'application/json'}
# Mock certificate retrieval
mock_get_cert.return_value = ('cert-pem-data', 'key-pem-data')
# Mock temp files context manager raising an exception
mock_temp_files.return_value.__enter__.side_effect = OSError('Temp file creation failed')
# Mock successful OIDC response
mock_oidc_response = mock.Mock()
mock_oidc_response.status_code = 200
mock_oidc_instance = mock.Mock()
mock_oidc_instance.make_request.return_value = mock_oidc_response
mock_oidc_client.return_value = mock_oidc_instance
mock_session = mock.Mock()
mock_session.headers = {}
mock_session_class.return_value = mock_session
result = ship(self.tarball_path)
assert result is True
# mTLS should fail, OIDC should succeed
mock_oidc_instance.make_request.assert_called_once()
@override_settings(
AUTOMATION_ANALYTICS_URL='https://analytics.example.com/api/ingress/v1/upload',
INSIGHTS_AGENT_MIME='application/vnd.redhat.tower.analytics+tgz',
INSIGHTS_CERT_PATH='/etc/pki/tls/certs/ca-bundle.crt',
REDHAT_USERNAME='test_user',
REDHAT_PASSWORD='test_pass', # NOSONAR
AWX_TASK_ENV={},
)
@mock.patch('awx.main.analytics.core.get_awx_http_client_headers')
@mock.patch('awx.main.analytics.core.OIDCClient')
@mock.patch('awx.main.analytics.core.get_or_generate_candlepin_certificate')
@mock.patch('awx.main.analytics.core.requests.Session')
def test_ship_no_certificate_available(self, mock_session_class, mock_get_cert, mock_oidc_client, mock_headers):
"""Test ship() when no Candlepin certificate is available."""
# Mock headers to avoid database access
mock_headers.return_value = {'Content-Type': 'application/json'}
# Mock no certificate available
mock_get_cert.return_value = (None, None)
# Mock successful OIDC response
mock_oidc_response = mock.Mock()
mock_oidc_response.status_code = 200
mock_oidc_instance = mock.Mock()
mock_oidc_instance.make_request.return_value = mock_oidc_response
mock_oidc_client.return_value = mock_oidc_instance
mock_session = mock.Mock()
mock_session.headers = {}
mock_session_class.return_value = mock_session
result = ship(self.tarball_path)
assert result is True
# Should skip mTLS and go straight to OIDC
mock_oidc_instance.make_request.assert_called_once()
@override_settings(
AUTOMATION_ANALYTICS_URL='https://analytics.example.com/api/ingress/v1/upload',
INSIGHTS_AGENT_MIME='application/vnd.redhat.tower.analytics+tgz',
INSIGHTS_CERT_PATH='/etc/pki/tls/certs/ca-bundle.crt',
REDHAT_USERNAME='test_user',
REDHAT_PASSWORD='test_pass', # NOSONAR
AWX_TASK_ENV={},
)
@mock.patch('awx.main.analytics.core.get_awx_http_client_headers')
@mock.patch('awx.main.analytics.core.OIDCClient')
@mock.patch('awx.main.analytics.core._temp_cert_files')
@mock.patch('awx.main.analytics.core.get_or_generate_candlepin_certificate')
@mock.patch('awx.main.analytics.core.requests.Session')
def test_ship_both_auth_methods_fail(self, mock_session_class, mock_get_cert, mock_temp_files, mock_oidc_client, mock_headers):
"""Test ship() when both mTLS and OIDC authentication fail."""
# Mock headers to avoid database access
mock_headers.return_value = {'Content-Type': 'application/json'}
# Mock certificate retrieval
mock_get_cert.return_value = ('cert-pem-data', 'key-pem-data')
# Mock temp files context manager
mock_temp_files.return_value.__enter__.return_value = ('/tmp/cert.pem', '/tmp/key.pem')
mock_temp_files.return_value.__exit__.return_value = None
# Mock failed mTLS response
mock_mtls_response = mock.Mock()
mock_mtls_response.status_code = 401
mock_session = mock.Mock()
mock_session.headers = {}
mock_session.post.return_value = mock_mtls_response
mock_session_class.return_value = mock_session
# Mock failed OIDC response
mock_oidc_response = mock.Mock()
mock_oidc_response.status_code = 403
mock_oidc_response.text = 'Forbidden'
mock_oidc_instance = mock.Mock()
mock_oidc_instance.make_request.return_value = mock_oidc_response
mock_oidc_client.return_value = mock_oidc_instance
result = ship(self.tarball_path)
assert result is False
mock_session.post.assert_called_once()
mock_oidc_instance.make_request.assert_called_once()

View File

@@ -39,7 +39,7 @@ def test_unified_job_detail_exclusive_fields():
For each type, assert that the only fields allowed to be exclusive to
detail view are the allowed types
"""
allowed_detail_fields = frozenset(('result_traceback', 'job_args', 'job_cwd', 'job_env', 'event_processing_finished', 'artifacts', 'extra_vars'))
allowed_detail_fields = frozenset(('result_traceback', 'job_args', 'job_cwd', 'job_env', 'event_processing_finished', 'artifacts'))
for cls in UnifiedJob.__subclasses__():
list_serializer = getattr(serializers, '{}ListSerializer'.format(cls.__name__))
detail_serializer = getattr(serializers, '{}Serializer'.format(cls.__name__))

View File

@@ -1,310 +0,0 @@
# Copyright (c) 2026 Ansible, Inc.
# All Rights Reserved.
"""Tests for candlepin_cert management command."""
from io import StringIO
from unittest import mock
import pytest
from django.core.management import call_command
from django.test.utils import override_settings
class TestCandlepinCertCommand:
"""Tests for candlepin_cert management command."""
@mock.patch('awx.main.management.commands.candlepin_cert._save_candlepin_registration_to_db')
@mock.patch('awx.main.management.commands.candlepin_cert.CandlepinClient')
@mock.patch('awx.main.management.commands.candlepin_cert.resolve_registration_credentials')
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
@override_settings(
AWX_ANALYTICS_CANDLEPIN_URL='https://test.example.com',
AWX_ANALYTICS_CANDLEPIN_CA=None,
AWX_ANALYTICS_CANDLEPIN_PROXY_URL=None,
)
def test_register_success(self, mock_fetch_cert, mock_resolve_creds, mock_client_class, mock_save_reg):
"""Test successful registration."""
# No existing cert
mock_fetch_cert.return_value = (None, None, None)
# Valid credentials
mock_resolve_creds.return_value = ('test_user', 'test_pass', 'test_org', 'install-uuid', None)
# Mock successful registration
mock_client = mock.Mock()
mock_client.register_consumer.return_value = ('cert-pem', 'key-pem', 'consumer-uuid')
mock_client_class.return_value = mock_client
# Mock successful save
mock_save_reg.return_value = True
out = StringIO()
call_command('candlepin_cert', 'register', stdout=out, stderr=StringIO())
output = out.getvalue()
assert 'Registered successfully' in output
assert 'consumer-uuid' in output
mock_client.register_consumer.assert_called_once_with('test_user', 'test_pass', 'test_org', install_uuid='install-uuid')
mock_save_reg.assert_called_once_with('cert-pem', 'key-pem', 'consumer-uuid')
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
def test_register_already_registered_without_force(self, mock_fetch_cert):
"""Test registration fails when cert already exists and --force not provided."""
# Existing cert
mock_fetch_cert.return_value = ('existing-cert', 'existing-key', 'existing-uuid')
out = StringIO()
call_command('candlepin_cert', 'register', stdout=out, stderr=StringIO())
output = out.getvalue()
assert 'already stored' in output
assert '--force' in output
@mock.patch('awx.main.management.commands.candlepin_cert._save_candlepin_registration_to_db')
@mock.patch('awx.main.management.commands.candlepin_cert.CandlepinClient')
@mock.patch('awx.main.management.commands.candlepin_cert.resolve_registration_credentials')
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
@override_settings(
AWX_ANALYTICS_CANDLEPIN_URL='https://test.example.com',
AWX_ANALYTICS_CANDLEPIN_CA=None,
AWX_ANALYTICS_CANDLEPIN_PROXY_URL=None,
)
def test_register_with_force_flag(self, mock_fetch_cert, mock_resolve_creds, mock_client_class, mock_save_reg):
"""Test registration succeeds with --force even when cert exists."""
# Existing cert
mock_fetch_cert.return_value = ('existing-cert', 'existing-key', 'existing-uuid')
# Valid credentials
mock_resolve_creds.return_value = ('test_user', 'test_pass', 'test_org', 'install-uuid', None)
# Mock successful registration
mock_client = mock.Mock()
mock_client.register_consumer.return_value = ('new-cert-pem', 'new-key-pem', 'new-consumer-uuid')
mock_client_class.return_value = mock_client
# Mock successful save
mock_save_reg.return_value = True
out = StringIO()
call_command('candlepin_cert', 'register', '--force', stdout=out, stderr=StringIO())
output = out.getvalue()
assert 'Registered successfully' in output
mock_client.register_consumer.assert_called_once()
mock_save_reg.assert_called_once_with('new-cert-pem', 'new-key-pem', 'new-consumer-uuid')
@mock.patch('awx.main.management.commands.candlepin_cert.resolve_registration_credentials')
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
def test_register_missing_credentials(self, mock_fetch_cert, mock_resolve_creds):
"""Test registration fails when credentials are missing."""
mock_fetch_cert.return_value = (None, None, None)
# Missing credentials
mock_resolve_creds.return_value = (None, None, None, None, ['username', 'password'])
err = StringIO()
with pytest.raises(SystemExit) as exc_info:
call_command('candlepin_cert', 'register', stderr=err)
assert exc_info.value.code == 1
error_output = err.getvalue()
assert 'Missing required value' in error_output
@mock.patch('awx.main.management.commands.candlepin_cert._save_candlepin_cert_to_db')
@mock.patch('awx.main.management.commands.candlepin_cert.CandlepinClient')
@mock.patch('awx.main.management.commands.candlepin_cert.parse_cert')
@mock.patch('awx.main.management.commands.candlepin_cert.needs_renewal')
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
@override_settings(
AWX_ANALYTICS_CANDLEPIN_URL='https://test.example.com',
AWX_ANALYTICS_CANDLEPIN_CA=None,
AWX_ANALYTICS_CANDLEPIN_PROXY_URL=None,
AWX_ANALYTICS_CANDLEPIN_RENEWAL_THRESHOLD_DAYS=90,
)
def test_renew_success(self, mock_fetch_cert, mock_needs_renewal, mock_parse_cert, mock_client_class, mock_save_cert):
"""Test successful certificate renewal."""
# Existing cert
mock_fetch_cert.return_value = ('old-cert', 'old-key', 'consumer-uuid')
# Parse cert returns metadata
mock_parse_cert.side_effect = [
{'serial': '123', 'cn': 'test', 'not_after': '2026-06-01', 'days_remaining': 10}, # Current cert
{'serial': '456', 'cn': 'test', 'not_after': '2027-06-01', 'days_remaining': 365}, # Renewed cert
]
# Renewal needed
mock_needs_renewal.return_value = True
# Mock successful check-in and renewal
mock_client = mock.Mock()
mock_client.checkin.return_value = True
mock_client.regenerate_cert.return_value = ('new-cert', 'new-key')
mock_client_class.return_value = mock_client
mock_save_cert.return_value = True
out = StringIO()
call_command('candlepin_cert', 'renew', stdout=out, stderr=StringIO())
output = out.getvalue()
assert 'Check-in successful' in output
assert 'Certificate renewed successfully' in output
assert 'saved to database' in output
mock_client.checkin.assert_called_once_with('consumer-uuid', 'old-cert', 'old-key')
mock_client.regenerate_cert.assert_called_once()
mock_save_cert.assert_called_once_with('new-cert', 'new-key')
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
def test_renew_no_cert_in_db(self, mock_fetch_cert):
"""Test renew fails when no certificate exists in database."""
mock_fetch_cert.return_value = (None, None, None)
err = StringIO()
with pytest.raises(SystemExit) as exc_info:
call_command('candlepin_cert', 'renew', stderr=err)
assert exc_info.value.code == 1
error_output = err.getvalue()
assert 'No Candlepin identity certificate found' in error_output
assert 'Run the register subcommand first' in error_output
@mock.patch('awx.main.management.commands.candlepin_cert.CandlepinClient')
@mock.patch('awx.main.management.commands.candlepin_cert.parse_cert')
@mock.patch('awx.main.management.commands.candlepin_cert.needs_renewal')
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
@override_settings(
AWX_ANALYTICS_CANDLEPIN_URL='https://test.example.com',
AWX_ANALYTICS_CANDLEPIN_CA=None,
AWX_ANALYTICS_CANDLEPIN_PROXY_URL=None,
AWX_ANALYTICS_CANDLEPIN_RENEWAL_THRESHOLD_DAYS=90,
)
def test_renew_not_needed(self, mock_fetch_cert, mock_needs_renewal, mock_parse_cert, mock_client_class):
"""Test renew when certificate is still valid and renewal not needed."""
mock_fetch_cert.return_value = ('cert', 'key', 'consumer-uuid')
# Parse cert returns healthy cert
mock_parse_cert.return_value = {'serial': '123', 'cn': 'test', 'not_after': '2027-01-01', 'days_remaining': 200}
# Renewal not needed
mock_needs_renewal.return_value = False
# Mock successful check-in
mock_client = mock.Mock()
mock_client.checkin.return_value = True
mock_client_class.return_value = mock_client
out = StringIO()
call_command('candlepin_cert', 'renew', stdout=out, stderr=StringIO())
output = out.getvalue()
assert 'Check-in successful' in output
assert 'No renewal needed' in output
mock_client.checkin.assert_called_once()
mock_client.regenerate_cert.assert_not_called()
@mock.patch('awx.main.management.commands.candlepin_cert._save_candlepin_cert_to_db')
@mock.patch('awx.main.management.commands.candlepin_cert.CandlepinClient')
@mock.patch('awx.main.management.commands.candlepin_cert.parse_cert')
@mock.patch('awx.main.management.commands.candlepin_cert.needs_renewal')
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
@override_settings(
AWX_ANALYTICS_CANDLEPIN_URL='https://test.example.com',
AWX_ANALYTICS_CANDLEPIN_CA=None,
AWX_ANALYTICS_CANDLEPIN_PROXY_URL=None,
AWX_ANALYTICS_CANDLEPIN_RENEWAL_THRESHOLD_DAYS=90,
)
def test_renew_with_force_flag(self, mock_fetch_cert, mock_needs_renewal, mock_parse_cert, mock_client_class, mock_save_cert):
"""Test renew --force renews even when not needed."""
mock_fetch_cert.return_value = ('cert', 'key', 'consumer-uuid')
# Parse cert
mock_parse_cert.side_effect = [
{'serial': '123', 'cn': 'test', 'not_after': '2027-01-01', 'days_remaining': 200}, # Current cert (healthy)
{'serial': '456', 'cn': 'test', 'not_after': '2027-06-01', 'days_remaining': 365}, # New cert
]
# Would not need renewal without --force
mock_needs_renewal.return_value = False
# Mock successful operations
mock_client = mock.Mock()
mock_client.checkin.return_value = True
mock_client.regenerate_cert.return_value = ('new-cert', 'new-key')
mock_client_class.return_value = mock_client
mock_save_cert.return_value = True
out = StringIO()
call_command('candlepin_cert', 'renew', '--force', stdout=out, stderr=StringIO())
output = out.getvalue()
assert 'forced via --force' in output
assert 'Certificate renewed successfully' in output
mock_client.regenerate_cert.assert_called_once()
@mock.patch('awx.main.management.commands.candlepin_cert.CandlepinClient')
@mock.patch('awx.main.management.commands.candlepin_cert.parse_cert')
@mock.patch('awx.main.management.commands.candlepin_cert.needs_renewal')
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
@override_settings(
AWX_ANALYTICS_CANDLEPIN_URL='https://test.example.com',
AWX_ANALYTICS_CANDLEPIN_CA=None,
AWX_ANALYTICS_CANDLEPIN_PROXY_URL=None,
AWX_ANALYTICS_CANDLEPIN_RENEWAL_THRESHOLD_DAYS=90,
)
def test_renew_checkin_failure(self, mock_fetch_cert, mock_needs_renewal, mock_parse_cert, mock_client_class):
"""Test renew handles check-in failure gracefully."""
mock_fetch_cert.return_value = ('cert', 'key', 'consumer-uuid')
mock_parse_cert.return_value = {'serial': '123', 'cn': 'test', 'not_after': '2027-01-01', 'days_remaining': 100}
mock_needs_renewal.return_value = False # Not needed for renewal, just testing check-in failure
# Mock failed check-in
mock_client = mock.Mock()
mock_client.checkin.return_value = False
mock_client_class.return_value = mock_client
err = StringIO()
with pytest.raises(SystemExit) as exc_info:
call_command('candlepin_cert', 'renew', stderr=err)
assert exc_info.value.code == 1
error_output = err.getvalue()
assert 'Check-in with Candlepin failed' in error_output
@mock.patch('awx.main.management.commands.candlepin_cert.CandlepinClient')
@mock.patch('awx.main.management.commands.candlepin_cert.parse_cert')
@mock.patch('awx.main.management.commands.candlepin_cert.needs_renewal')
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
@override_settings(
AWX_ANALYTICS_CANDLEPIN_URL='https://test.example.com',
AWX_ANALYTICS_CANDLEPIN_CA=None,
AWX_ANALYTICS_CANDLEPIN_PROXY_URL=None,
AWX_ANALYTICS_CANDLEPIN_RENEWAL_THRESHOLD_DAYS=90,
)
def test_renew_regenerate_cert_failure(self, mock_fetch_cert, mock_needs_renewal, mock_parse_cert, mock_client_class):
"""Test renew handles certificate regeneration failure."""
mock_fetch_cert.return_value = ('cert', 'key', 'consumer-uuid')
mock_parse_cert.return_value = {'serial': '123', 'cn': 'test', 'not_after': '2026-06-01', 'days_remaining': 10}
mock_needs_renewal.return_value = True
# Mock successful check-in but failed regeneration
mock_client = mock.Mock()
mock_client.checkin.return_value = True
mock_client.regenerate_cert.side_effect = Exception('Certificate regeneration failed')
mock_client_class.return_value = mock_client
err = StringIO()
with pytest.raises(SystemExit) as exc_info:
call_command('candlepin_cert', 'renew', stderr=err)
assert exc_info.value.code == 1
error_output = err.getvalue()
assert 'Certificate renewal failed' in error_output

Some files were not shown because too many files have changed in this diff Show More