Compare commits

..

16 Commits

Author SHA1 Message Date
thedoubl3j
0377b3830b Update operator timeout
* updated the operator timeout to near healthy run time
2026-01-23 10:39:55 -05:00
Jake Jackson
331ae92475 Merge branch 'devel' into move_to_dispatcherd 2026-01-23 10:06:30 -05:00
Jake Jackson
e355df6cc6 Merge branch 'devel' into move_to_dispatcherd 2026-01-22 11:03:45 -05:00
thedoubl3j
806ef7c345 Fix attribute error in server logs
* on a secret hunt to find the hidden attribute error in the server logs
2026-01-20 16:08:46 -05:00
thedoubl3j
8acdd0cbf4 Fix imports and linter findings
* add back more missing things
2026-01-20 15:41:33 -05:00
thedoubl3j
381c7fdc5d Adjust heartbeat arg and more formatting
* fixed the call to cluster_node_heartbeat missing binder
* formatting/linter fixes
2026-01-20 15:21:23 -05:00
thedoubl3j
d75fcc13f6 Fix dispatcher run call and remove dispatch setting
* added back some code that was lost in the merge conflict
* remove dispatcher mock publish setting
2026-01-20 14:38:54 -05:00
thedoubl3j
bb8ecc5919 Add back hazmat for config and remove baseworker
* added back hazmat per @alancoding feedback around config
* removed baseworker completely and refactored it into the callback
  worker
2026-01-19 20:33:23 -05:00
thedoubl3j
1019ac0439 Update function comments 2026-01-19 20:30:41 -05:00
thedoubl3j
cddee29f23 More chainsaw work
* fixed imports and addressed clusternode heartbeat test
* took a chainsaw to task.py as well
2026-01-19 20:30:41 -05:00
thedoubl3j
3b896a00a9 Clean up imports and fix some tests
* removed unused imports
* adjusted test import to pull correct method
2026-01-19 20:30:41 -05:00
thedoubl3j
e386326498 Remove control and hazmat (squash this not done)
* moved status out and deleted control as no longer needed
* removed hazmat
2026-01-19 20:30:41 -05:00
thedoubl3j
5209bfcf82 add back auto_max_workers
* added back get_auto_max_workers into common utils
* formatting edits
2026-01-19 20:30:07 -05:00
thedoubl3j
ebd51cd074 Keep callback receiver working
* remove any code that is not used by the callback receiver
2026-01-19 20:26:04 -05:00
thedoubl3j
f9f4bf2d1a Add decorator
* moved to dispatcher decorator
* updated as many as I could find
2026-01-19 20:26:04 -05:00
thedoubl3j
e55578b64e WIP First pass
* started removing feature flags and adjusting logic
* WIP
2026-01-19 20:26:04 -05:00
157 changed files with 783 additions and 6351 deletions

View File

@@ -24,7 +24,7 @@ in as the first entry for your PR title.
##### STEPS TO REPRODUCE AND EXTRA INFO ##### ADDITIONAL INFORMATION
<!--- <!---
Include additional information to help people understand the change here. Include additional information to help people understand the change here.
For bugs that don't have a linked bug report, a step-by-step reproduction For bugs that don't have a linked bug report, a step-by-step reproduction

View File

@@ -45,45 +45,15 @@ jobs:
make docker-runner 2>&1 | tee schema-diff.txt make docker-runner 2>&1 | tee schema-diff.txt
exit ${PIPESTATUS[0]} exit ${PIPESTATUS[0]}
- name: Validate OpenAPI schema - name: Add schema diff to job summary
id: schema-validation
continue-on-error: true
run: |
AWX_DOCKER_ARGS='-e GITHUB_ACTIONS' \
AWX_DOCKER_CMD='make validate-openapi-schema' \
make docker-runner 2>&1 | tee schema-validation.txt
exit ${PIPESTATUS[0]}
- name: Add schema validation and diff to job summary
if: always() if: always()
# show text and if for some reason, it can't be generated, state that it can't be. # show text and if for some reason, it can't be generated, state that it can't be.
run: | run: |
echo "## API Schema Check Results" >> $GITHUB_STEP_SUMMARY echo "## API Schema Change Detection Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY
# Show validation status
echo "### OpenAPI Validation" >> $GITHUB_STEP_SUMMARY
if [ -f schema-validation.txt ] && grep -q "✓ Schema is valid" schema-validation.txt; then
echo "✅ **Status:** PASSED - Schema is valid OpenAPI 3.0.3" >> $GITHUB_STEP_SUMMARY
else
echo "❌ **Status:** FAILED - Schema validation failed" >> $GITHUB_STEP_SUMMARY
if [ -f schema-validation.txt ]; then
echo "" >> $GITHUB_STEP_SUMMARY
echo "<details><summary>Validation errors</summary>" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo '```' >> $GITHUB_STEP_SUMMARY
cat schema-validation.txt >> $GITHUB_STEP_SUMMARY
echo '```' >> $GITHUB_STEP_SUMMARY
echo "</details>" >> $GITHUB_STEP_SUMMARY
fi
fi
echo "" >> $GITHUB_STEP_SUMMARY
# Show schema changes
echo "### Schema Changes" >> $GITHUB_STEP_SUMMARY
if [ -f schema-diff.txt ]; then if [ -f schema-diff.txt ]; then
if grep -q "^+" schema-diff.txt || grep -q "^-" schema-diff.txt; then if grep -q "^+" schema-diff.txt || grep -q "^-" schema-diff.txt; then
echo "**Changes detected** between this PR and the base branch" >> $GITHUB_STEP_SUMMARY echo "### Schema changes detected" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY
# Truncate to first 1000 lines to stay under GitHub's 1MB summary limit # Truncate to first 1000 lines to stay under GitHub's 1MB summary limit
TOTAL_LINES=$(wc -l < schema-diff.txt) TOTAL_LINES=$(wc -l < schema-diff.txt)
@@ -95,8 +65,8 @@ jobs:
head -n 1000 schema-diff.txt >> $GITHUB_STEP_SUMMARY head -n 1000 schema-diff.txt >> $GITHUB_STEP_SUMMARY
echo '```' >> $GITHUB_STEP_SUMMARY echo '```' >> $GITHUB_STEP_SUMMARY
else else
echo "No schema changes detected" >> $GITHUB_STEP_SUMMARY echo "### No schema changes detected" >> $GITHUB_STEP_SUMMARY
fi fi
else else
echo "Unable to generate schema diff" >> $GITHUB_STEP_SUMMARY echo "### Unable to generate schema diff" >> $GITHUB_STEP_SUMMARY
fi fi

View File

@@ -4,46 +4,14 @@ env:
LC_ALL: "C.UTF-8" # prevent ERROR: Ansible could not initialize the preferred locale: unsupported locale setting LC_ALL: "C.UTF-8" # prevent ERROR: Ansible could not initialize the preferred locale: unsupported locale setting
CI_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} CI_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DEV_DOCKER_OWNER: ${{ github.repository_owner }} DEV_DOCKER_OWNER: ${{ github.repository_owner }}
COMPOSE_TAG: ${{ github.base_ref || github.ref_name || 'devel' }} COMPOSE_TAG: ${{ github.base_ref || 'devel' }}
UPSTREAM_REPOSITORY_ID: 91594105 UPSTREAM_REPOSITORY_ID: 91594105
on: on:
pull_request: pull_request:
push: push:
branches: branches:
- devel # needed to publish code coverage post-merge - devel # needed to publish code coverage post-merge
schedule:
- cron: '0 12,18 * * 1-5'
workflow_dispatch: {}
jobs: jobs:
trigger-release-branches:
name: "Dispatch CI to release branches"
if: github.event_name == 'schedule'
runs-on: ubuntu-latest
permissions:
actions: write
steps:
- name: Trigger CI on release_4.6
id: dispatch_release_46
continue-on-error: true
run: gh workflow run ci.yml --ref release_4.6
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GH_REPO: ${{ github.repository }}
- name: Trigger CI on stable-2.6
id: dispatch_stable_26
continue-on-error: true
run: gh workflow run ci.yml --ref stable-2.6
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GH_REPO: ${{ github.repository }}
- name: Check dispatch results
if: steps.dispatch_release_46.outcome == 'failure' || steps.dispatch_stable_26.outcome == 'failure'
run: |
echo "One or more dispatches failed:"
echo " release_4.6: ${{ steps.dispatch_release_46.outcome }}"
echo " stable-2.6: ${{ steps.dispatch_stable_26.outcome }}"
exit 1
common-tests: common-tests:
name: ${{ matrix.tests.name }} name: ${{ matrix.tests.name }}
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -94,11 +62,7 @@ jobs:
run: | run: |
if [ -f "reports/coverage.xml" ]; then if [ -f "reports/coverage.xml" ]; then
sed -i '2i<!-- PR ${{ github.event.pull_request.number }} -->' reports/coverage.xml sed -i '2i<!-- PR ${{ github.event.pull_request.number }} -->' reports/coverage.xml
echo "Injected PR number ${{ github.event.pull_request.number }} into reports/coverage.xml" echo "Injected PR number ${{ github.event.pull_request.number }} into coverage.xml"
fi
if [ -f "awxkit/coverage.xml" ]; then
sed -i '2i<!-- PR ${{ github.event.pull_request.number }} -->' awxkit/coverage.xml
echo "Injected PR number ${{ github.event.pull_request.number }} into awxkit/coverage.xml"
fi fi
- name: Upload test coverage to Codecov - name: Upload test coverage to Codecov
@@ -145,32 +109,28 @@ jobs:
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: ${{ matrix.tests.name }}-artifacts name: ${{ matrix.tests.name }}-artifacts
path: | path: reports/coverage.xml
reports/coverage.xml
awxkit/coverage.xml
retention-days: 5 retention-days: 5
- name: >- - name: Upload awx jUnit test reports
Upload ${{
matrix.tests.coverage-upload-name || 'awx'
}} jUnit test reports to the unified dashboard
if: >- if: >-
!cancelled() !cancelled()
&& steps.make-run.outputs.test-result-files != '' && steps.make-run.outputs.test-result-files != ''
&& github.event_name == 'push' && github.event_name == 'push'
&& env.UPSTREAM_REPOSITORY_ID == github.repository_id && env.UPSTREAM_REPOSITORY_ID == github.repository_id
&& github.ref_name == github.event.repository.default_branch && github.ref_name == github.event.repository.default_branch
uses: ansible/gh-action-record-test-results@3784db66a1b7fb3809999a7251c8a7203a7ffbe8 run: |
with: for junit_file in $(echo '${{ steps.make-run.outputs.test-result-files }}' | sed 's/,/ /')
aggregation-server-url: ${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_URL }} do
http-auth-password: >- curl \
${{ secrets.PDE_ORG_RESULTS_UPLOAD_PASSWORD }} -v \
http-auth-username: >- --user "${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_USER }}:${{ secrets.PDE_ORG_RESULTS_UPLOAD_PASSWORD }}" \
${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_USER }} --form "xunit_xml=@${junit_file}" \
project-component-name: >- --form "component_name=${{ matrix.tests.coverage-upload-name || 'awx' }}" \
${{ matrix.tests.coverage-upload-name || 'awx' }} --form "git_commit_sha=${{ github.sha }}" \
test-result-files: >- --form "git_repository_url=https://github.com/${{ github.repository }}" \
${{ steps.make-run.outputs.test-result-files }} "${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_URL }}/api/results/upload/"
done
dev-env: dev-env:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -334,16 +294,18 @@ jobs:
&& github.event_name == 'push' && github.event_name == 'push'
&& env.UPSTREAM_REPOSITORY_ID == github.repository_id && env.UPSTREAM_REPOSITORY_ID == github.repository_id
&& github.ref_name == github.event.repository.default_branch && github.ref_name == github.event.repository.default_branch
uses: ansible/gh-action-record-test-results@3784db66a1b7fb3809999a7251c8a7203a7ffbe8 run: |
with: for junit_file in $(echo '${{ steps.make-run.outputs.test-result-files }}' | sed 's/,/ /')
aggregation-server-url: ${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_URL }} do
http-auth-password: >- curl \
${{ secrets.PDE_ORG_RESULTS_UPLOAD_PASSWORD }} -v \
http-auth-username: >- --user "${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_USER }}:${{ secrets.PDE_ORG_RESULTS_UPLOAD_PASSWORD }}" \
${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_USER }} --form "xunit_xml=@${junit_file}" \
project-component-name: awx --form "component_name=awx" \
test-result-files: >- --form "git_commit_sha=${{ github.sha }}" \
${{ steps.make-run.outputs.test-result-files }} --form "git_repository_url=https://github.com/${{ github.repository }}" \
"${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_URL }}/api/results/upload/"
done
collection-integration: collection-integration:
name: awx_collection integration name: awx_collection integration

View File

@@ -1,176 +0,0 @@
# Sync OpenAPI Spec on Merge
#
# This workflow runs when code is merged to the devel branch.
# It runs the dev environment to generate the OpenAPI spec, then syncs it to
# the central spec repository.
#
# FLOW: PR merged → push to branch → dev environment runs → spec synced to central repo
#
# NOTE: This is an inlined version for testing with private forks.
# Production version will use a reusable workflow from the org repos.
name: Sync OpenAPI Spec on Merge
env:
LC_ALL: "C.UTF-8"
DEV_DOCKER_OWNER: ${{ github.repository_owner }}
on:
push:
branches:
- devel
workflow_dispatch: # Allow manual triggering for testing
jobs:
sync-openapi-spec:
name: Sync OpenAPI spec to central repo
runs-on: ubuntu-latest
permissions:
packages: write
contents: read
steps:
- name: Checkout Controller repository
uses: actions/checkout@v4
with:
show-progress: false
- name: Build awx_devel image to use for schema gen
uses: ./.github/actions/awx_devel_image
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
private-github-key: ${{ secrets.PRIVATE_GITHUB_KEY }}
- name: Generate API Schema
run: |
DEV_DOCKER_TAG_BASE=ghcr.io/${OWNER_LC} \
COMPOSE_TAG=${{ github.base_ref || github.ref_name }} \
docker run -u $(id -u) --rm -v ${{ github.workspace }}:/awx_devel/:Z \
--workdir=/awx_devel `make print-DEVEL_IMAGE_NAME` /start_tests.sh genschema
- name: Verify spec file exists
run: |
SPEC_FILE="./schema.json"
if [ ! -f "$SPEC_FILE" ]; then
echo "❌ Spec file not found at $SPEC_FILE"
echo "Contents of workspace:"
ls -la .
exit 1
fi
echo "✅ Found spec file at $SPEC_FILE"
- name: Checkout spec repo
id: checkout_spec_repo
continue-on-error: true
uses: actions/checkout@v4
with:
repository: ansible-automation-platform/aap-openapi-specs
ref: ${{ github.ref_name }}
path: spec-repo
token: ${{ secrets.OPENAPI_SPEC_SYNC_TOKEN }}
- name: Fail if branch doesn't exist
if: steps.checkout_spec_repo.outcome == 'failure'
run: |
echo "##[error]❌ Branch '${{ github.ref_name }}' does not exist in the central spec repository."
echo "##[error]Expected branch: ${{ github.ref_name }}"
echo "##[error]This branch must be created in the spec repo before specs can be synced."
exit 1
- name: Compare specs
id: compare
run: |
COMPONENT_SPEC="./schema.json"
SPEC_REPO_FILE="spec-repo/controller.json"
# Check if spec file exists in spec repo
if [ ! -f "$SPEC_REPO_FILE" ]; then
echo "Spec file doesn't exist in spec repo - will create new file"
echo "has_diff=true" >> $GITHUB_OUTPUT
echo "is_new_file=true" >> $GITHUB_OUTPUT
else
# Compare files
if diff -q "$COMPONENT_SPEC" "$SPEC_REPO_FILE" > /dev/null; then
echo "✅ No differences found - specs are identical"
echo "has_diff=false" >> $GITHUB_OUTPUT
else
echo "📝 Differences found - spec has changed"
echo "has_diff=true" >> $GITHUB_OUTPUT
echo "is_new_file=false" >> $GITHUB_OUTPUT
fi
fi
- name: Update spec file
if: steps.compare.outputs.has_diff == 'true'
run: |
cp "./schema.json" "spec-repo/controller.json"
echo "✅ Updated spec-repo/controller.json"
- name: Create PR in spec repo
if: steps.compare.outputs.has_diff == 'true'
working-directory: spec-repo
env:
GH_TOKEN: ${{ secrets.OPENAPI_SPEC_SYNC_TOKEN }}
COMMIT_MESSAGE: ${{ github.event.head_commit.message }}
run: |
# Configure git
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
# Create branch for PR
SHORT_SHA="${{ github.sha }}"
SHORT_SHA="${SHORT_SHA:0:7}"
BRANCH_NAME="update-Controller-${{ github.ref_name }}-${SHORT_SHA}"
git checkout -b "$BRANCH_NAME"
# Add and commit changes
git add "controller.json"
if [ "${{ steps.compare.outputs.is_new_file }}" == "true" ]; then
COMMIT_MSG="Add Controller OpenAPI spec for ${{ github.ref_name }}"
else
COMMIT_MSG="Update Controller OpenAPI spec for ${{ github.ref_name }}"
fi
git commit -m "$COMMIT_MSG
Synced from ${{ github.repository }}@${{ github.sha }}
Source branch: ${{ github.ref_name }}
Co-Authored-By: github-actions[bot] <github-actions[bot]@users.noreply.github.com>"
# Push branch
git push origin "$BRANCH_NAME"
# Create PR
PR_TITLE="[${{ github.ref_name }}] Update Controller spec from merged commit"
PR_BODY="## Summary
Automated OpenAPI spec sync from component repository merge.
**Source:** ${{ github.repository }}@${{ github.sha }}
**Branch:** \`${{ github.ref_name }}\`
**Component:** \`Controller\`
**Spec File:** \`controller.json\`
## Changes
$(if [ "${{ steps.compare.outputs.is_new_file }}" == "true" ]; then echo "- 🆕 New spec file created"; else echo "- 📝 Spec file updated with latest changes"; fi)
## Source Commit
\`\`\`
${COMMIT_MESSAGE}
\`\`\`
---
🤖 This PR was automatically generated by the OpenAPI spec sync workflow."
gh pr create \
--title "$PR_TITLE" \
--body "$PR_BODY" \
--base "${{ github.ref_name }}" \
--head "$BRANCH_NAME"
echo "✅ Created PR in spec repo"
- name: Report results
if: always()
run: |
if [ "${{ steps.compare.outputs.has_diff }}" == "true" ]; then
echo "📝 Spec sync completed - PR created in spec repo"
else
echo "✅ Spec sync completed - no changes needed"
fi

View File

@@ -1,6 +1,6 @@
-include awx/ui/Makefile -include awx/ui/Makefile
PYTHON := $(notdir $(shell for i in python3.12 python3.11 python3; do command -v $$i; done|sed 1q)) PYTHON := $(notdir $(shell for i in python3.12 python3; do command -v $$i; done|sed 1q))
SHELL := bash SHELL := bash
DOCKER_COMPOSE ?= docker compose DOCKER_COMPOSE ?= docker compose
OFFICIAL ?= no OFFICIAL ?= no
@@ -79,7 +79,7 @@ RECEPTOR_IMAGE ?= quay.io/ansible/receptor:devel
SRC_ONLY_PKGS ?= cffi,pycparser,psycopg,twilio SRC_ONLY_PKGS ?= cffi,pycparser,psycopg,twilio
# These should be upgraded in the AWX and Ansible venv before attempting # These should be upgraded in the AWX and Ansible venv before attempting
# to install the actual requirements # to install the actual requirements
VENV_BOOTSTRAP ?= pip==25.3 setuptools==80.9.0 setuptools_scm[toml]==9.2.2 wheel==0.46.3 cython==3.1.3 VENV_BOOTSTRAP ?= pip==25.3 setuptools==80.9.0 setuptools_scm[toml]==9.2.2 wheel==0.45.1 cython==3.1.3
NAME ?= awx NAME ?= awx
@@ -289,7 +289,7 @@ dispatcher:
@if [ "$(VENV_BASE)" ]; then \ @if [ "$(VENV_BASE)" ]; then \
. $(VENV_BASE)/awx/bin/activate; \ . $(VENV_BASE)/awx/bin/activate; \
fi; \ fi; \
$(PYTHON) manage.py dispatcherd $(PYTHON) manage.py run_dispatcher
## Run to start the zeromq callback receiver ## Run to start the zeromq callback receiver
receiver: receiver:
@@ -579,10 +579,6 @@ detect-schema-change: genschema
# diff exits with 1 when files differ - capture but don't fail # diff exits with 1 when files differ - capture but don't fail
-diff -u -b reference-schema.json schema.json -diff -u -b reference-schema.json schema.json
validate-openapi-schema: genschema
@echo "Validating OpenAPI schema from schema.json..."
@python3 -c "from openapi_spec_validator import validate; import json; spec = json.load(open('schema.json')); validate(spec); print('✓ OpenAPI Schema is valid!')"
docker-compose-clean: awx/projects docker-compose-clean: awx/projects
$(DOCKER_COMPOSE) -f tools/docker-compose/_sources/docker-compose.yml rm -sf $(DOCKER_COMPOSE) -f tools/docker-compose/_sources/docker-compose.yml rm -sf

View File

@@ -89,7 +89,7 @@ class DeprecatedCredentialField(serializers.IntegerField):
def to_internal_value(self, pk): def to_internal_value(self, pk):
try: try:
pk = int(pk) pk = int(pk)
except (ValueError, TypeError): except ValueError:
self.fail('invalid') self.fail('invalid')
try: try:
Credential.objects.get(pk=pk) Credential.objects.get(pk=pk)

View File

@@ -131,14 +131,8 @@ class LoggedLoginView(auth_views.LoginView):
class LoggedLogoutView(auth_views.LogoutView): class LoggedLogoutView(auth_views.LogoutView):
# Override http_method_names to allow GET requests (Django 5.2+ defaults to POST only)
http_method_names = ["get", "post", "options"]
success_url_allowed_hosts = set(settings.LOGOUT_ALLOWED_HOSTS.split(",")) if settings.LOGOUT_ALLOWED_HOSTS else set() success_url_allowed_hosts = set(settings.LOGOUT_ALLOWED_HOSTS.split(",")) if settings.LOGOUT_ALLOWED_HOSTS else set()
def get(self, request, *args, **kwargs):
"""Handle GET requests for logout (for backward compatibility)."""
return self.post(request, *args, **kwargs)
def dispatch(self, request, *args, **kwargs): def dispatch(self, request, *args, **kwargs):
if is_proxied_request(): if is_proxied_request():
# 1) We intentionally don't obey ?next= here, just always redirect to platform login # 1) We intentionally don't obey ?next= here, just always redirect to platform login

View File

@@ -111,7 +111,7 @@ class UnifiedJobEventPagination(Pagination):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.use_limit_paginator = False self.use_limit_paginator = False
self.limit_pagination = LimitPagination() self.limit_pagination = LimitPagination()
super().__init__(*args, **kwargs) return super().__init__(*args, **kwargs)
def paginate_queryset(self, queryset, request, view=None): def paginate_queryset(self, queryset, request, view=None):
if 'limit' in request.query_params: if 'limit' in request.query_params:

View File

@@ -9,50 +9,6 @@ from drf_spectacular.views import (
) )
def filter_credential_type_schema(
result,
generator, # NOSONAR
request, # NOSONAR
public, # NOSONAR
):
"""
Postprocessing hook to filter CredentialType kind enum values.
For CredentialTypeRequest and PatchedCredentialTypeRequest schemas (POST/PUT/PATCH),
filter the 'kind' enum to only show 'cloud' and 'net' values.
This ensures the OpenAPI schema accurately reflects that only 'cloud' and 'net'
credential types can be created or modified via the API, matching the validation
in CredentialTypeSerializer.validate().
Args:
result: The OpenAPI schema dict to be modified
generator, request, public: Required by drf-spectacular interface (unused)
Returns:
The modified OpenAPI schema dict
"""
schemas = result.get('components', {}).get('schemas', {})
# Filter CredentialTypeRequest (POST/PUT) - field is required
if 'CredentialTypeRequest' in schemas:
kind_prop = schemas['CredentialTypeRequest'].get('properties', {}).get('kind', {})
if 'enum' in kind_prop:
# Filter to only cloud and net (no None - field is required)
kind_prop['enum'] = ['cloud', 'net']
kind_prop['description'] = "* `cloud` - Cloud\\n* `net` - Network"
# Filter PatchedCredentialTypeRequest (PATCH) - field is optional
if 'PatchedCredentialTypeRequest' in schemas:
kind_prop = schemas['PatchedCredentialTypeRequest'].get('properties', {}).get('kind', {})
if 'enum' in kind_prop:
# Filter to only cloud and net (None allowed - field can be omitted in PATCH)
kind_prop['enum'] = ['cloud', 'net', None]
kind_prop['description'] = "* `cloud` - Cloud\\n* `net` - Network"
return result
class CustomAutoSchema(AutoSchema): class CustomAutoSchema(AutoSchema):
"""Custom AutoSchema to add swagger_topic to tags and handle deprecated endpoints.""" """Custom AutoSchema to add swagger_topic to tags and handle deprecated endpoints."""

View File

@@ -1230,7 +1230,7 @@ class OrganizationSerializer(BaseSerializer, OpaQueryPathMixin):
# to a team. This provides a hint to the ui so it can know to not # to a team. This provides a hint to the ui so it can know to not
# display these roles for team role selection. # display these roles for team role selection.
for key in ('admin_role', 'member_role'): for key in ('admin_role', 'member_role'):
if summary_dict and key in summary_dict.get('object_roles', {}): if key in summary_dict.get('object_roles', {}):
summary_dict['object_roles'][key]['user_only'] = True summary_dict['object_roles'][key]['user_only'] = True
return summary_dict return summary_dict
@@ -2165,13 +2165,13 @@ class BulkHostDeleteSerializer(serializers.Serializer):
attrs['hosts_data'] = attrs['host_qs'].values() attrs['hosts_data'] = attrs['host_qs'].values()
if len(attrs['host_qs']) == 0: if len(attrs['host_qs']) == 0:
error_hosts = dict.fromkeys(attrs['hosts'], "Hosts do not exist or you lack permission to delete it") error_hosts = {host: "Hosts do not exist or you lack permission to delete it" for host in attrs['hosts']}
raise serializers.ValidationError({'hosts': error_hosts}) raise serializers.ValidationError({'hosts': error_hosts})
if len(attrs['host_qs']) < len(attrs['hosts']): if len(attrs['host_qs']) < len(attrs['hosts']):
hosts_exists = [host['id'] for host in attrs['hosts_data']] hosts_exists = [host['id'] for host in attrs['hosts_data']]
failed_hosts = list(set(attrs['hosts']).difference(hosts_exists)) failed_hosts = list(set(attrs['hosts']).difference(hosts_exists))
error_hosts = dict.fromkeys(failed_hosts, "Hosts do not exist or you lack permission to delete it") error_hosts = {host: "Hosts do not exist or you lack permission to delete it" for host in failed_hosts}
raise serializers.ValidationError({'hosts': error_hosts}) raise serializers.ValidationError({'hosts': error_hosts})
# Getting all inventories that the hosts can be in # Getting all inventories that the hosts can be in
@@ -3527,7 +3527,7 @@ class JobRelaunchSerializer(BaseSerializer):
choices=NEW_JOB_TYPE_CHOICES, choices=NEW_JOB_TYPE_CHOICES,
write_only=True, write_only=True,
) )
credential_passwords = VerbatimField(required=False, write_only=True) credential_passwords = VerbatimField(required=True, write_only=True)
class Meta: class Meta:
model = Job model = Job

View File

@@ -1,6 +1,6 @@
{% if content_only %}<div class="nocode ansi_fore ansi_back{% if dark %} ansi_dark{% endif %}">{% else %} {% if content_only %}<div class="nocode ansi_fore ansi_back{% if dark %} ansi_dark{% endif %}">{% else %}
<!DOCTYPE HTML> <!DOCTYPE HTML>
<html lang="en"> <html>
<head> <head>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8"> <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
<title>{{ title }}</title> <title>{{ title }}</title>

View File

@@ -52,7 +52,6 @@ from ansi2html import Ansi2HTMLConverter
from datetime import timezone as dt_timezone from datetime import timezone as dt_timezone
from wsgiref.util import FileWrapper from wsgiref.util import FileWrapper
from drf_spectacular.utils import extend_schema_view, extend_schema
# django-ansible-base # django-ansible-base
from ansible_base.lib.utils.requests import get_remote_hosts from ansible_base.lib.utils.requests import get_remote_hosts
@@ -379,10 +378,6 @@ class DashboardJobsGraphView(APIView):
class InstanceList(ListCreateAPIView): class InstanceList(ListCreateAPIView):
"""
Creates an instance if used on a Kubernetes or OpenShift deployment of Ansible Automation Platform.
"""
name = _("Instances") name = _("Instances")
model = models.Instance model = models.Instance
serializer_class = serializers.InstanceSerializer serializer_class = serializers.InstanceSerializer
@@ -1459,7 +1454,7 @@ class CredentialList(ListCreateAPIView):
@extend_schema_if_available( @extend_schema_if_available(
extensions={ extensions={
"x-ai-description": "Create a new credential. The `inputs` field contain type-specific input fields. The required fields depend on related `credential_type`. Use GET /v2/credential_types/{id}/ (tool name: controller.credential_types_retrieve) and inspect `inputs` field for the specific credential type's expected schema. The fields `user` and `team` are deprecated and should not be included in the payload." "x-ai-description": "Create a new credential. The `inputs` field contain type-specific input fields. The required fields depend on related `credential_type`. Use GET /v2/credential_types/{id}/ (tool name: controller.credential_types_retrieve) and inspect `inputs` field for the specific credential type's expected schema."
} }
) )
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
@@ -1608,11 +1603,7 @@ class CredentialExternalTest(SubDetailAPIView):
obj_permission_type = 'use' obj_permission_type = 'use'
resource_purpose = 'test external credential' resource_purpose = 'test external credential'
@extend_schema_if_available(extensions={"x-ai-description": """Test update the input values and metadata of an external credential. @extend_schema_if_available(extensions={"x-ai-description": "Test update the input values and metadata of an external credential"})
This endpoint supports testing credentials that connect to external secret management systems
such as CyberArk AIM, CyberArk Conjur, HashiCorp Vault, AWS Secrets Manager, Azure Key Vault,
Centrify Vault, Thycotic DevOps Secrets Vault, and GitHub App Installation Access Token Lookup.
It does not support standard credential types such as Machine, SCM, and Cloud."""})
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
obj = self.get_object() obj = self.get_object()
backend_kwargs = {} backend_kwargs = {}
@@ -1626,16 +1617,13 @@ class CredentialExternalTest(SubDetailAPIView):
with set_environ(**settings.AWX_TASK_ENV): with set_environ(**settings.AWX_TASK_ENV):
obj.credential_type.plugin.backend(**backend_kwargs) obj.credential_type.plugin.backend(**backend_kwargs)
return Response({}, status=status.HTTP_202_ACCEPTED) return Response({}, status=status.HTTP_202_ACCEPTED)
except requests.exceptions.HTTPError: except requests.exceptions.HTTPError as exc:
message = """Test operation is not supported for credential type {}. message = 'HTTP {}'.format(exc.response.status_code)
This endpoint only supports credentials that connect to return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
external secret management systems such as CyberArk, HashiCorp
Vault, or cloud-based secret managers.""".format(obj.credential_type.kind)
return Response({'detail': message}, status=status.HTTP_400_BAD_REQUEST)
except Exception as exc: except Exception as exc:
message = exc.__class__.__name__ message = exc.__class__.__name__
exc_args = getattr(exc, 'args', []) args = getattr(exc, 'args', [])
for a in exc_args: for a in args:
if isinstance(getattr(a, 'reason', None), ConnectTimeoutError): if isinstance(getattr(a, 'reason', None), ConnectTimeoutError):
message = str(a.reason) message = str(a.reason)
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST) return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
@@ -1693,8 +1681,8 @@ class CredentialTypeExternalTest(SubDetailAPIView):
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST) return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
except Exception as exc: except Exception as exc:
message = exc.__class__.__name__ message = exc.__class__.__name__
args_exc = getattr(exc, 'args', []) args = getattr(exc, 'args', [])
for a in args_exc: for a in args:
if isinstance(getattr(a, 'reason', None), ConnectTimeoutError): if isinstance(getattr(a, 'reason', None), ConnectTimeoutError):
message = str(a.reason) message = str(a.reason)
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST) return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
@@ -2481,11 +2469,6 @@ class JobTemplateDetail(RelatedJobsPreventDeleteMixin, RetrieveUpdateDestroyAPIV
resource_purpose = 'job template detail' resource_purpose = 'job template detail'
@extend_schema_view(
retrieve=extend_schema(
extensions={'x-ai-description': 'List job template launch criteria'},
)
)
class JobTemplateLaunch(RetrieveAPIView): class JobTemplateLaunch(RetrieveAPIView):
model = models.JobTemplate model = models.JobTemplate
obj_permission_type = 'start' obj_permission_type = 'start'
@@ -2494,9 +2477,6 @@ class JobTemplateLaunch(RetrieveAPIView):
resource_purpose = 'launch a job from a job template' resource_purpose = 'launch a job from a job template'
def update_raw_data(self, data): def update_raw_data(self, data):
"""
Use the ID of a job template to retrieve its launch details.
"""
try: try:
obj = self.get_object() obj = self.get_object()
except PermissionDenied: except PermissionDenied:
@@ -3330,11 +3310,6 @@ class WorkflowJobTemplateLabelList(JobTemplateLabelList):
resource_purpose = 'labels of a workflow job template' resource_purpose = 'labels of a workflow job template'
@extend_schema_view(
retrieve=extend_schema(
extensions={'x-ai-description': 'List workflow job template launch criteria.'},
)
)
class WorkflowJobTemplateLaunch(RetrieveAPIView): class WorkflowJobTemplateLaunch(RetrieveAPIView):
model = models.WorkflowJobTemplate model = models.WorkflowJobTemplate
obj_permission_type = 'start' obj_permission_type = 'start'
@@ -3343,9 +3318,6 @@ class WorkflowJobTemplateLaunch(RetrieveAPIView):
resource_purpose = 'launch a workflow job from a workflow job template' resource_purpose = 'launch a workflow job from a workflow job template'
def update_raw_data(self, data): def update_raw_data(self, data):
"""
Use the ID of a workflow job template to retrieve its launch details.
"""
try: try:
obj = self.get_object() obj = self.get_object()
except PermissionDenied: except PermissionDenied:
@@ -3738,11 +3710,6 @@ class JobCancel(GenericCancelView):
return super().post(request, *args, **kwargs) return super().post(request, *args, **kwargs)
@extend_schema_view(
retrieve=extend_schema(
extensions={'x-ai-description': 'List job relaunch criteria'},
)
)
class JobRelaunch(RetrieveAPIView): class JobRelaunch(RetrieveAPIView):
model = models.Job model = models.Job
obj_permission_type = 'start' obj_permission_type = 'start'
@@ -3750,7 +3717,6 @@ class JobRelaunch(RetrieveAPIView):
resource_purpose = 'relaunch a job' resource_purpose = 'relaunch a job'
def update_raw_data(self, data): def update_raw_data(self, data):
"""Use the ID of a job to retrieve data on retry attempts and necessary passwords."""
data = super(JobRelaunch, self).update_raw_data(data) data = super(JobRelaunch, self).update_raw_data(data)
try: try:
obj = self.get_object() obj = self.get_object()

View File

@@ -25,6 +25,7 @@ import requests
from ansible_base.lib.utils.schema import extend_schema_if_available from ansible_base.lib.utils.schema import extend_schema_if_available
from awx import MODE
from awx.api.generics import APIView from awx.api.generics import APIView
from awx.conf.registry import settings_registry from awx.conf.registry import settings_registry
from awx.main.analytics import all_collectors from awx.main.analytics import all_collectors
@@ -32,7 +33,7 @@ from awx.main.ha import is_ha_environment
from awx.main.tasks.system import clear_setting_cache from awx.main.tasks.system import clear_setting_cache
from awx.main.utils import get_awx_version, get_custom_venv_choices from awx.main.utils import get_awx_version, get_custom_venv_choices
from awx.main.utils.licensing import validate_entitlement_manifest from awx.main.utils.licensing import validate_entitlement_manifest
from awx.api.versioning import URLPathVersioning, reverse from awx.api.versioning import URLPathVersioning, reverse, drf_reverse
from awx.main.constants import PRIVILEGE_ESCALATION_METHODS from awx.main.constants import PRIVILEGE_ESCALATION_METHODS
from awx.main.models import Project, Organization, Instance, InstanceGroup, JobTemplate from awx.main.models import Project, Organization, Instance, InstanceGroup, JobTemplate
from awx.main.utils import set_environ from awx.main.utils import set_environ
@@ -61,6 +62,8 @@ class ApiRootView(APIView):
data['custom_logo'] = settings.CUSTOM_LOGO data['custom_logo'] = settings.CUSTOM_LOGO
data['custom_login_info'] = settings.CUSTOM_LOGIN_INFO data['custom_login_info'] = settings.CUSTOM_LOGIN_INFO
data['login_redirect_override'] = settings.LOGIN_REDIRECT_OVERRIDE data['login_redirect_override'] = settings.LOGIN_REDIRECT_OVERRIDE
if MODE == 'development':
data['docs'] = drf_reverse('api:schema-swagger-ui')
return Response(data) return Response(data)

View File

@@ -133,7 +133,7 @@ class WebhookReceiverBase(APIView):
@csrf_exempt @csrf_exempt
@extend_schema_if_available(extensions={"x-ai-description": "Receive a webhook event and trigger a job"}) @extend_schema_if_available(extensions={"x-ai-description": "Receive a webhook event and trigger a job"})
def post(self, request, *args, **kwargs_in): def post(self, request, *args, **kwargs):
# Ensure that the full contents of the request are captured for multiple uses. # Ensure that the full contents of the request are captured for multiple uses.
request.body request.body

View File

@@ -1,41 +0,0 @@
import http.client
import socket
import urllib.error
import urllib.request
import logging
from django.conf import settings
logger = logging.getLogger(__name__)
def get_dispatcherd_metrics(request):
metrics_cfg = settings.METRICS_SUBSYSTEM_CONFIG.get('server', {}).get(settings.METRICS_SERVICE_DISPATCHER, {})
host = metrics_cfg.get('host', 'localhost')
port = metrics_cfg.get('port', 8015)
metrics_filter = []
if request is not None and hasattr(request, "query_params"):
try:
nodes_filter = request.query_params.getlist("node")
except Exception:
nodes_filter = []
if nodes_filter and settings.CLUSTER_HOST_ID not in nodes_filter:
return ''
try:
metrics_filter = request.query_params.getlist("metric")
except Exception:
metrics_filter = []
if metrics_filter:
# Right now we have no way of filtering the dispatcherd metrics
# so just avoid getting in the way if another metric is filtered for
return ''
url = f"http://{host}:{port}/metrics"
try:
with urllib.request.urlopen(url, timeout=1.0) as response:
payload = response.read()
if not payload:
return ''
return payload.decode('utf-8')
except (urllib.error.URLError, UnicodeError, socket.timeout, TimeoutError, http.client.HTTPException) as exc:
logger.debug(f"Failed to collect dispatcherd metrics from {url}: {exc}")
return ''

View File

@@ -15,7 +15,6 @@ from rest_framework.request import Request
from awx.main.consumers import emit_channel_notification from awx.main.consumers import emit_channel_notification
from awx.main.utils import is_testing from awx.main.utils import is_testing
from awx.main.utils.redis import get_redis_client from awx.main.utils.redis import get_redis_client
from .dispatcherd_metrics import get_dispatcherd_metrics
root_key = settings.SUBSYSTEM_METRICS_REDIS_KEY_PREFIX root_key = settings.SUBSYSTEM_METRICS_REDIS_KEY_PREFIX
logger = logging.getLogger('awx.main.analytics') logger = logging.getLogger('awx.main.analytics')
@@ -399,6 +398,11 @@ class DispatcherMetrics(Metrics):
SetFloatM('workflow_manager_recorded_timestamp', 'Unix timestamp when metrics were last recorded'), SetFloatM('workflow_manager_recorded_timestamp', 'Unix timestamp when metrics were last recorded'),
SetFloatM('workflow_manager_spawn_workflow_graph_jobs_seconds', 'Time spent spawning workflow tasks'), SetFloatM('workflow_manager_spawn_workflow_graph_jobs_seconds', 'Time spent spawning workflow tasks'),
SetFloatM('workflow_manager_get_tasks_seconds', 'Time spent loading workflow tasks from db'), SetFloatM('workflow_manager_get_tasks_seconds', 'Time spent loading workflow tasks from db'),
# dispatcher subsystem metrics
SetIntM('dispatcher_pool_scale_up_events', 'Number of times local dispatcher scaled up a worker since startup'),
SetIntM('dispatcher_pool_active_task_count', 'Number of active tasks in the worker pool when last task was submitted'),
SetIntM('dispatcher_pool_max_worker_count', 'Highest number of workers in worker pool in last collection interval, about 20s'),
SetFloatM('dispatcher_availability', 'Fraction of time (in last collection interval) dispatcher was able to receive messages'),
] ]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@@ -426,12 +430,8 @@ class CallbackReceiverMetrics(Metrics):
def metrics(request): def metrics(request):
output_text = '' output_text = ''
output_text += DispatcherMetrics().generate_metrics(request) for m in [DispatcherMetrics(), CallbackReceiverMetrics()]:
output_text += CallbackReceiverMetrics().generate_metrics(request) output_text += m.generate_metrics(request)
dispatcherd_metrics = get_dispatcherd_metrics(request)
if dispatcherd_metrics:
output_text += dispatcherd_metrics
return output_text return output_text
@@ -481,6 +481,13 @@ class CallbackReceiverMetricsServer(MetricsServer):
super().__init__(settings.METRICS_SERVICE_CALLBACK_RECEIVER, registry) super().__init__(settings.METRICS_SERVICE_CALLBACK_RECEIVER, registry)
class DispatcherMetricsServer(MetricsServer):
def __init__(self):
registry = CollectorRegistry(auto_describe=True)
registry.register(CustomToPrometheusMetricsCollector(DispatcherMetrics(metrics_have_changed=False)))
super().__init__(settings.METRICS_SERVICE_DISPATCHER, registry)
class WebsocketsMetricsServer(MetricsServer): class WebsocketsMetricsServer(MetricsServer):
def __init__(self): def __init__(self):
registry = CollectorRegistry(auto_describe=True) registry = CollectorRegistry(auto_describe=True)

View File

@@ -82,7 +82,7 @@ class MainConfig(AppConfig):
def configure_dispatcherd(self): def configure_dispatcherd(self):
"""This implements the default configuration for dispatcherd """This implements the default configuration for dispatcherd
If running the tasking service like awx-manage dispatcherd, If running the tasking service like awx-manage run_dispatcher,
some additional config will be applied on top of this. some additional config will be applied on top of this.
This configuration provides the minimum such that code can submit This configuration provides the minimum such that code can submit
tasks to pg_notify to run those tasks. tasks to pg_notify to run those tasks.

View File

@@ -11,7 +11,6 @@ __all__ = [
'CAN_CANCEL', 'CAN_CANCEL',
'ACTIVE_STATES', 'ACTIVE_STATES',
'STANDARD_INVENTORY_UPDATE_ENV', 'STANDARD_INVENTORY_UPDATE_ENV',
'OIDC_CREDENTIAL_TYPE_NAMESPACES',
] ]
PRIVILEGE_ESCALATION_METHODS = [ PRIVILEGE_ESCALATION_METHODS = [
@@ -141,6 +140,3 @@ org_role_to_permission = {
'execution_environment_admin_role': 'add_executionenvironment', 'execution_environment_admin_role': 'add_executionenvironment',
'auditor_role': 'view_project', # TODO: also doesnt really work 'auditor_role': 'view_project', # TODO: also doesnt really work
} }
# OIDC credential type namespaces for feature flag filtering
OIDC_CREDENTIAL_TYPE_NAMESPACES = ['hashivault-kv-oidc', 'hashivault-ssh-oidc']

View File

@@ -27,14 +27,10 @@ def get_dispatcherd_config(for_service: bool = False, mock_publish: bool = False
"pool_kwargs": { "pool_kwargs": {
"min_workers": settings.JOB_EVENT_WORKERS, "min_workers": settings.JOB_EVENT_WORKERS,
"max_workers": max_workers, "max_workers": max_workers,
# This must be less than max_workers to make sense, which is usually 4
# With reserve of 1, after a burst of tasks, load needs to down to 4-1=3
# before we return to min_workers
"scaledown_reserve": 1,
}, },
"main_kwargs": {"node_id": settings.CLUSTER_HOST_ID}, "main_kwargs": {"node_id": settings.CLUSTER_HOST_ID},
"process_manager_cls": "ForkServerManager", "process_manager_cls": "ForkServerManager",
"process_manager_kwargs": {"preload_modules": ['awx.main.dispatch.prefork']}, "process_manager_kwargs": {"preload_modules": ['awx.main.dispatch.hazmat']},
}, },
"brokers": {}, "brokers": {},
"publish": {}, "publish": {},
@@ -42,8 +38,8 @@ def get_dispatcherd_config(for_service: bool = False, mock_publish: bool = False
} }
if mock_publish: if mock_publish:
config["brokers"]["dispatcherd.testing.brokers.noop"] = {} config["brokers"]["noop"] = {}
config["publish"]["default_broker"] = "dispatcherd.testing.brokers.noop" config["publish"]["default_broker"] = "noop"
else: else:
config["brokers"]["pg_notify"] = { config["brokers"]["pg_notify"] = {
"config": get_pg_notify_params(), "config": get_pg_notify_params(),
@@ -60,11 +56,5 @@ def get_dispatcherd_config(for_service: bool = False, mock_publish: bool = False
} }
config["brokers"]["pg_notify"]["channels"] = ['tower_broadcast_all', 'tower_settings_change', get_task_queuename()] config["brokers"]["pg_notify"]["channels"] = ['tower_broadcast_all', 'tower_settings_change', get_task_queuename()]
metrics_cfg = settings.METRICS_SUBSYSTEM_CONFIG.get('server', {}).get(settings.METRICS_SERVICE_DISPATCHER)
if metrics_cfg:
config["service"]["metrics_kwargs"] = {
"host": metrics_cfg.get("host", "localhost"),
"port": metrics_cfg.get("port", 8015),
}
return config return config

View File

@@ -18,7 +18,7 @@ django.setup() # noqa
from django.conf import settings from django.conf import settings
# Preload all periodic tasks so their imports will be in shared memory # Preload all periodic tasks so their imports will be in shared memory
for name, options in settings.DISPATCHER_SCHEDULE.items(): for name, options in settings.CELERYBEAT_SCHEDULE.items():
resolve_callable(options['task']) resolve_callable(options['task'])

View File

@@ -1,4 +1,6 @@
import logging import logging
import os
import time
from multiprocessing import Process from multiprocessing import Process
@@ -13,12 +15,13 @@ class PoolWorker(object):
""" """
A simple wrapper around a multiprocessing.Process that tracks a worker child process. A simple wrapper around a multiprocessing.Process that tracks a worker child process.
The worker process runs the provided target function. The worker process runs the provided target function and tracks its creation time.
""" """
def __init__(self, target, args): def __init__(self, target, args, **kwargs):
self.process = Process(target=target, args=args) self.process = Process(target=target, args=args)
self.process.daemon = True self.process.daemon = True
self.creation_time = time.monotonic()
def start(self): def start(self):
self.process.start() self.process.start()
@@ -35,20 +38,44 @@ class WorkerPool(object):
pool = WorkerPool(workers_num=4) # spawn four worker processes pool = WorkerPool(workers_num=4) # spawn four worker processes
""" """
def __init__(self, workers_num=None): pool_cls = PoolWorker
self.workers_num = workers_num or settings.JOB_EVENT_WORKERS debug_meta = ''
def init_workers(self, target): def __init__(self, workers_num=None):
self.name = settings.CLUSTER_HOST_ID
self.pid = os.getpid()
self.workers_num = workers_num or settings.JOB_EVENT_WORKERS
self.workers = []
def __len__(self):
return len(self.workers)
def init_workers(self, target, *target_args):
self.target = target
self.target_args = target_args
for idx in range(self.workers_num): for idx in range(self.workers_num):
# It's important to close these because we're _about_ to fork, and we self.up()
# don't want the forked processes to inherit the open sockets
# for the DB and cache connections (that way lies race conditions) def up(self):
django_connection.close() idx = len(self.workers)
django_cache.close() # It's important to close these because we're _about_ to fork, and we
worker = PoolWorker(target, (idx,)) # don't want the forked processes to inherit the open sockets
try: # for the DB and cache connections (that way lies race conditions)
worker.start() django_connection.close()
except Exception: django_cache.close()
logger.exception('could not fork') worker = self.pool_cls(self.target, (idx,) + self.target_args)
else: self.workers.append(worker)
logger.debug('scaling up worker pid:{}'.format(worker.process.pid)) try:
worker.start()
except Exception:
logger.exception('could not fork')
else:
logger.debug('scaling up worker pid:{}'.format(worker.process.pid))
return idx, worker
def stop(self, signum):
try:
for worker in self.workers:
os.kill(worker.pid, signum)
except Exception:
logger.exception('could not kill {}'.format(worker.pid))

View File

@@ -1,6 +1,9 @@
from datetime import timedelta
import logging import logging
from django.db.models import Q from django.db.models import Q
from django.conf import settings
from django.utils.timezone import now as tz_now
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from awx.main.models import Instance, UnifiedJob, WorkflowJob from awx.main.models import Instance, UnifiedJob, WorkflowJob
@@ -47,6 +50,26 @@ def reap_job(j, status, job_explanation=None):
logger.error(f'{j.log_format} is no longer {status_before}; reaping') logger.error(f'{j.log_format} is no longer {status_before}; reaping')
def reap_waiting(instance=None, status='failed', job_explanation=None, grace_period=None, excluded_uuids=None, ref_time=None):
"""
Reap all jobs in waiting for this instance.
"""
if grace_period is None:
grace_period = settings.JOB_WAITING_GRACE_PERIOD + settings.TASK_MANAGER_TIMEOUT
if instance is None:
hostname = Instance.objects.my_hostname()
else:
hostname = instance.hostname
if ref_time is None:
ref_time = tz_now()
jobs = UnifiedJob.objects.filter(status='waiting', modified__lte=ref_time - timedelta(seconds=grace_period), controller_node=hostname)
if excluded_uuids:
jobs = jobs.exclude(celery_task_id__in=excluded_uuids)
for j in jobs:
reap_job(j, status, job_explanation=job_explanation)
def reap(instance=None, status='failed', job_explanation=None, excluded_uuids=None, ref_time=None): def reap(instance=None, status='failed', job_explanation=None, excluded_uuids=None, ref_time=None):
""" """
Reap all jobs in running for this instance. Reap all jobs in running for this instance.

View File

@@ -19,24 +19,49 @@ def signame(sig):
return dict((k, v) for v, k in signal.__dict__.items() if v.startswith('SIG') and not v.startswith('SIG_'))[sig] return dict((k, v) for v, k in signal.__dict__.items() if v.startswith('SIG') and not v.startswith('SIG_'))[sig]
class AWXConsumerRedis(object): class WorkerSignalHandler:
def __init__(self):
self.kill_now = False
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGINT, self.exit_gracefully)
def exit_gracefully(self, *args, **kwargs):
self.kill_now = True
class AWXConsumerBase(object):
last_stats = time.time()
def __init__(self, name, worker, queues=[], pool=None):
self.should_stop = False
def __init__(self, name, worker):
self.name = name self.name = name
self.pool = WorkerPool() self.total_messages = 0
self.pool.init_workers(worker.work_loop) self.queues = queues
self.worker = worker
self.pool = pool
if pool is None:
self.pool = WorkerPool()
self.pool.init_workers(self.worker.work_loop)
self.redis = get_redis_client() self.redis = get_redis_client()
def run(self): def run(self, *args, **kwargs):
signal.signal(signal.SIGINT, self.stop) signal.signal(signal.SIGINT, self.stop)
signal.signal(signal.SIGTERM, self.stop) signal.signal(signal.SIGTERM, self.stop)
# Child should implement other things here
def stop(self, signum, frame):
self.should_stop = True
logger.warning('received {}, stopping'.format(signame(signum)))
raise SystemExit()
class AWXConsumerRedis(AWXConsumerBase):
def run(self, *args, **kwargs):
super(AWXConsumerRedis, self).run(*args, **kwargs)
logger.info(f'Callback receiver started with pid={os.getpid()}') logger.info(f'Callback receiver started with pid={os.getpid()}')
db.connection.close() # logs use database, so close connection db.connection.close() # logs use database, so close connection
while True: while True:
time.sleep(60) time.sleep(60)
def stop(self, signum, frame):
logger.warning('received {}, stopping'.format(signame(signum)))
raise SystemExit()

View File

@@ -26,6 +26,7 @@ from awx.main.models.events import emit_event_detail
from awx.main.utils.profiling import AWXProfiler from awx.main.utils.profiling import AWXProfiler
from awx.main.tasks.system import events_processed_hook from awx.main.tasks.system import events_processed_hook
import awx.main.analytics.subsystem_metrics as s_metrics import awx.main.analytics.subsystem_metrics as s_metrics
from .base import WorkerSignalHandler
logger = logging.getLogger('awx.main.commands.run_callback_receiver') logger = logging.getLogger('awx.main.commands.run_callback_receiver')
@@ -56,16 +57,6 @@ def job_stats_wrapup(job_identifier, event=None):
logger.exception('Worker failed to save stats or emit notifications: Job {}'.format(job_identifier)) logger.exception('Worker failed to save stats or emit notifications: Job {}'.format(job_identifier))
class WorkerSignalHandler:
def __init__(self):
self.kill_now = False
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGINT, self.exit_gracefully)
def exit_gracefully(self, *args, **kwargs):
self.kill_now = True
class CallbackBrokerWorker: class CallbackBrokerWorker:
""" """
A worker implementation that deserializes callback event data and persists A worker implementation that deserializes callback event data and persists
@@ -77,13 +68,13 @@ class CallbackBrokerWorker:
MAX_RETRIES = 2 MAX_RETRIES = 2
INDIVIDUAL_EVENT_RETRIES = 3 INDIVIDUAL_EVENT_RETRIES = 3
last_stats = time.time()
last_flush = time.time()
total = 0 total = 0
last_event = '' last_event = ''
prof = None prof = None
def __init__(self): def __init__(self):
self.last_stats = time.time()
self.last_flush = time.time()
self.buff = {} self.buff = {}
self.redis = get_redis_client() self.redis = get_redis_client()
self.subsystem_metrics = s_metrics.CallbackReceiverMetrics(auto_pipe_execute=False) self.subsystem_metrics = s_metrics.CallbackReceiverMetrics(auto_pipe_execute=False)

View File

@@ -1,3 +1,4 @@
import inspect
import logging import logging
import importlib import importlib
import time import time
@@ -36,13 +37,18 @@ def run_callable(body):
if 'guid' in body: if 'guid' in body:
set_guid(body.pop('guid')) set_guid(body.pop('guid'))
_call = resolve_callable(task) _call = resolve_callable(task)
if inspect.isclass(_call):
# the callable is a class, e.g., RunJob; instantiate and
# return its `run()` method
_call = _call().run
log_extra = '' log_extra = ''
logger_method = logger.debug logger_method = logger.debug
if 'time_pub' in body: if ('time_ack' in body) and ('time_pub' in body):
time_publish = time.time() - body['time_pub'] time_publish = body['time_ack'] - body['time_pub']
if time_publish > 5.0: time_waiting = time.time() - body['time_ack']
if time_waiting > 5.0 or time_publish > 5.0:
# If task too a very long time to process, add this information to the log # If task too a very long time to process, add this information to the log
log_extra = f' took {time_publish:.4f} to send message' log_extra = f' took {time_publish:.4f} to ack, {time_waiting:.4f} in local dispatcher'
logger_method = logger.info logger_method = logger.info
# don't print kwargs, they often contain launch-time secrets # don't print kwargs, they often contain launch-time secrets
logger_method(f'task {uuid} starting {task}(*{args}){log_extra}') logger_method(f'task {uuid} starting {task}(*{args}){log_extra}')

View File

@@ -428,9 +428,6 @@ class CredentialInputField(JSONSchemaField):
# determine the defined fields for the associated credential type # determine the defined fields for the associated credential type
properties = {} properties = {}
for field in model_instance.credential_type.inputs.get('fields', []): for field in model_instance.credential_type.inputs.get('fields', []):
# Prevent users from providing values for internally resolved fields
if 'internal' in field:
continue
field = field.copy() field = field.copy()
properties[field['id']] = field properties[field['id']] = field
if field.get('choices', []): if field.get('choices', []):
@@ -569,7 +566,6 @@ class CredentialTypeInputField(JSONSchemaField):
}, },
'label': {'type': 'string'}, 'label': {'type': 'string'},
'help_text': {'type': 'string'}, 'help_text': {'type': 'string'},
'internal': {'type': 'boolean'},
'multiline': {'type': 'boolean'}, 'multiline': {'type': 'boolean'},
'secret': {'type': 'boolean'}, 'secret': {'type': 'boolean'},
'ask_at_runtime': {'type': 'boolean'}, 'ask_at_runtime': {'type': 'boolean'},

View File

@@ -1,88 +0,0 @@
import argparse
import inspect
import logging
import os
import sys
import yaml
from django.core.management.base import BaseCommand, CommandError
from django.db import connection
from dispatcherd.cli import (
CONTROL_ARG_SCHEMAS,
DEFAULT_CONFIG_FILE,
_base_cli_parent,
_control_common_parent,
_register_control_arguments,
_build_command_data_from_args,
)
from dispatcherd.config import setup as dispatcher_setup
from dispatcherd.factories import get_control_from_settings
from dispatcherd.service import control_tasks
from awx.main.dispatch.config import get_dispatcherd_config
from awx.main.management.commands.dispatcherd import ensure_no_dispatcherd_env_config
logger = logging.getLogger(__name__)
class Command(BaseCommand):
help = 'Dispatcher control operations'
def add_arguments(self, parser):
parser.description = 'Run dispatcherd control commands using awx-manage.'
base_parent = _base_cli_parent()
control_parent = _control_common_parent()
parser._add_container_actions(base_parent)
parser._add_container_actions(control_parent)
subparsers = parser.add_subparsers(dest='command', metavar='command')
subparsers.required = True
shared_parents = [base_parent, control_parent]
for command in control_tasks.__all__:
func = getattr(control_tasks, command, None)
doc = inspect.getdoc(func) or ''
summary = doc.splitlines()[0] if doc else None
command_parser = subparsers.add_parser(
command,
help=summary,
description=doc,
parents=shared_parents,
)
_register_control_arguments(command_parser, CONTROL_ARG_SCHEMAS.get(command))
def handle(self, *args, **options):
command = options.pop('command', None)
if not command:
raise CommandError('No dispatcher control command specified')
for django_opt in ('verbosity', 'traceback', 'no_color', 'force_color', 'skip_checks'):
options.pop(django_opt, None)
log_level = options.pop('log_level', 'DEBUG')
config_path = os.path.abspath(options.pop('config', DEFAULT_CONFIG_FILE))
expected_replies = options.pop('expected_replies', 1)
logging.basicConfig(level=getattr(logging, log_level), stream=sys.stdout)
logger.debug(f"Configured standard out logging at {log_level} level")
default_config = os.path.abspath(DEFAULT_CONFIG_FILE)
ensure_no_dispatcherd_env_config()
if config_path != default_config:
raise CommandError('The config path CLI option is not allowed for the awx-manage command')
if connection.vendor == 'sqlite':
raise CommandError('dispatcherctl is not supported with sqlite3; use a PostgreSQL database')
else:
logger.info('Using config generated from awx.main.dispatch.config.get_dispatcherd_config')
dispatcher_setup(get_dispatcherd_config())
schema_namespace = argparse.Namespace(**options)
data = _build_command_data_from_args(schema_namespace, command)
ctl = get_control_from_settings()
returned = ctl.control_with_reply(command, data=data, expected_replies=expected_replies)
self.stdout.write(yaml.dump(returned, default_flow_style=False))
if len(returned) < expected_replies:
logger.error(f'Obtained only {len(returned)} of {expected_replies}, exiting with non-zero code')
raise CommandError('dispatcherctl returned fewer replies than expected')

View File

@@ -1,85 +0,0 @@
# Copyright (c) 2015 Ansible, Inc.
# All Rights Reserved
import copy
import hashlib
import json
import logging
import logging.config
import os
from django.conf import settings
from django.core.cache import cache as django_cache
from django.core.management.base import BaseCommand, CommandError
from django.db import connection
from dispatcherd.config import setup as dispatcher_setup
from awx.main.dispatch.config import get_dispatcherd_config
logger = logging.getLogger('awx.main.dispatch')
from dispatcherd import run_service
def _json_default(value):
if isinstance(value, set):
return sorted(value)
if isinstance(value, tuple):
return list(value)
return str(value)
def _hash_config(config):
serialized = json.dumps(config, sort_keys=True, separators=(',', ':'), default=_json_default)
return hashlib.sha256(serialized.encode('utf-8')).hexdigest()
def ensure_no_dispatcherd_env_config():
if os.getenv('DISPATCHERD_CONFIG_FILE'):
raise CommandError('DISPATCHERD_CONFIG_FILE is set but awx-manage dispatcherd uses dynamic config from code')
class Command(BaseCommand):
help = (
'Run the background task service, this is the supported entrypoint since the introduction of dispatcherd as a library. '
'This replaces the prior awx-manage run_dispatcher service, and control actions are at awx-manage dispatcherctl.'
)
def add_arguments(self, parser):
return
def handle(self, *arg, **options):
ensure_no_dispatcherd_env_config()
self.configure_dispatcher_logging()
config = get_dispatcherd_config(for_service=True)
config_hash = _hash_config(config)
logger.info(
'Using dispatcherd config generated from awx.main.dispatch.config.get_dispatcherd_config (sha256=%s)',
config_hash,
)
# Close the connection, because the pg_notify broker will create new async connection
connection.close()
django_cache.close()
dispatcher_setup(config)
run_service()
def configure_dispatcher_logging(self):
# Apply special log rule for the parent process
special_logging = copy.deepcopy(settings.LOGGING)
changed_handlers = []
for handler_name, handler_config in special_logging.get('handlers', {}).items():
filters = handler_config.get('filters', [])
if 'dynamic_level_filter' in filters:
handler_config['filters'] = [flt for flt in filters if flt != 'dynamic_level_filter']
changed_handlers.append(handler_name)
logger.info(f'Dispatcherd main process replaced log level filter for handlers: {changed_handlers}')
# Apply the custom logging level here, before the asyncio code starts
special_logging.setdefault('loggers', {}).setdefault('dispatcherd', {})
special_logging['loggers']['dispatcherd']['level'] = settings.LOG_AGGREGATOR_LEVEL
logging.config.dictConfig(special_logging)

View File

@@ -3,6 +3,7 @@
import redis import redis
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError from django.core.management.base import BaseCommand, CommandError
import redis.exceptions import redis.exceptions
@@ -35,7 +36,11 @@ class Command(BaseCommand):
raise CommandError(f'Callback receiver could not connect to redis, error: {exc}') raise CommandError(f'Callback receiver could not connect to redis, error: {exc}')
try: try:
consumer = AWXConsumerRedis('callback_receiver', CallbackBrokerWorker()) consumer = AWXConsumerRedis(
'callback_receiver',
CallbackBrokerWorker(),
queues=[getattr(settings, 'CALLBACK_QUEUE', '')],
)
consumer.run() consumer.run()
except KeyboardInterrupt: except KeyboardInterrupt:
print('Terminating Callback Receiver') print('Terminating Callback Receiver')

View File

@@ -1,20 +1,26 @@
# Copyright (c) 2015 Ansible, Inc. # Copyright (c) 2015 Ansible, Inc.
# All Rights Reserved. # All Rights Reserved.
import logging import logging
import logging.config
import yaml import yaml
import copy
from django.core.management.base import CommandError from django.conf import settings
from django.core.management.base import BaseCommand, CommandError
from django.core.cache import cache as django_cache
from django.db import connection
from dispatcherd.factories import get_control_from_settings from dispatcherd.factories import get_control_from_settings
from dispatcherd import run_service
from dispatcherd.config import setup as dispatcher_setup
from awx.main.management.commands.dispatcherd import Command as DispatcherdCommand from awx.main.dispatch.config import get_dispatcherd_config
logger = logging.getLogger('awx.main.dispatch') logger = logging.getLogger('awx.main.dispatch')
class Command(DispatcherdCommand): class Command(BaseCommand):
help = 'Launch the task dispatcher (deprecated; use awx-manage dispatcherd)' help = 'Launch the task dispatcher'
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument('--status', dest='status', action='store_true', help='print the internal state of any running dispatchers') parser.add_argument('--status', dest='status', action='store_true', help='print the internal state of any running dispatchers')
@@ -28,10 +34,8 @@ class Command(DispatcherdCommand):
'Only running tasks can be canceled, queued tasks must be started before they can be canceled.' 'Only running tasks can be canceled, queued tasks must be started before they can be canceled.'
), ),
) )
super().add_arguments(parser)
def handle(self, *args, **options): def handle(self, *arg, **options):
logger.warning('awx-manage run_dispatcher is deprecated; use awx-manage dispatcherd')
if options.get('status'): if options.get('status'):
ctl = get_control_from_settings() ctl = get_control_from_settings()
running_data = ctl.control_with_reply('status') running_data = ctl.control_with_reply('status')
@@ -61,4 +65,28 @@ class Command(DispatcherdCommand):
results.append(result) results.append(result)
print(yaml.dump(results, default_flow_style=False)) print(yaml.dump(results, default_flow_style=False))
return return
return super().handle(*args, **options)
self.configure_dispatcher_logging()
# Close the connection, because the pg_notify broker will create new async connection
connection.close()
django_cache.close()
dispatcher_setup(get_dispatcherd_config(for_service=True))
run_service()
dispatcher_setup(get_dispatcherd_config(for_service=True))
run_service()
def configure_dispatcher_logging(self):
# Apply special log rule for the parent process
special_logging = copy.deepcopy(settings.LOGGING)
for handler_name, handler_config in special_logging.get('handlers', {}).items():
filters = handler_config.get('filters', [])
if 'dynamic_level_filter' in filters:
handler_config['filters'] = [flt for flt in filters if flt != 'dynamic_level_filter']
logger.info(f'Dispatcherd main process replaced log level filter for {handler_name} handler')
# Apply the custom logging level here, before the asyncio code starts
special_logging.setdefault('loggers', {}).setdefault('dispatcherd', {})
special_logging['loggers']['dispatcherd']['level'] = settings.LOG_AGGREGATOR_LEVEL
logging.config.dictConfig(special_logging)

View File

@@ -21,6 +21,6 @@ class Migration(migrations.Migration):
] ]
operations = [ operations = [
migrations.RunPython(setup_tower_managed_defaults, migrations.RunPython.noop), migrations.RunPython(setup_tower_managed_defaults),
migrations.RunPython(setup_rbac_role_system_administrator, migrations.RunPython.noop), migrations.RunPython(setup_rbac_role_system_administrator),
] ]

View File

@@ -98,5 +98,5 @@ class Migration(migrations.Migration):
] ]
operations = [ operations = [
migrations.RunPython(convert_controller_role_definitions, migrations.RunPython.noop), migrations.RunPython(convert_controller_role_definitions),
] ]

View File

@@ -3,15 +3,19 @@ from django.db import migrations, models
from awx.main.migrations._create_system_jobs import delete_clear_tokens_sjt from awx.main.migrations._create_system_jobs import delete_clear_tokens_sjt
# --- START of function merged from 0203_rename_github_app_kind.py ---
def update_github_app_kind(apps, schema_editor): def update_github_app_kind(apps, schema_editor):
""" """
Updates the 'namespace' field for CredentialType records Updates the 'kind' field for CredentialType records
from 'github_app' to 'github_app_lookup'. from 'github_app' to 'github_app_lookup'.
This addresses a change in the entry point key for the GitHub App plugin. This addresses a change in the entry point key for the GitHub App plugin.
""" """
CredentialType = apps.get_model('main', 'CredentialType') CredentialType = apps.get_model('main', 'CredentialType')
db_alias = schema_editor.connection.alias db_alias = schema_editor.connection.alias
CredentialType.objects.using(db_alias).filter(namespace='github_app').update(namespace='github_app_lookup') CredentialType.objects.using(db_alias).filter(kind='github_app').update(kind='github_app_lookup')
# --- END of function merged from 0203_rename_github_app_kind.py ---
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -114,5 +118,7 @@ class Migration(migrations.Migration):
max_length=32, max_length=32,
), ),
), ),
# --- START of operations merged from 0203_rename_github_app_kind.py ---
migrations.RunPython(update_github_app_kind, migrations.RunPython.noop), migrations.RunPython(update_github_app_kind, migrations.RunPython.noop),
# --- END of operations merged from 0203_rename_github_app_kind.py ---
] ]

View File

@@ -1,29 +0,0 @@
# Generated by Django 5.2.8 on 2026-02-20 03:39
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('main', '0204_squashed_deletions'),
]
operations = [
migrations.AlterModelOptions(
name='instancegroup',
options={
'default_permissions': ('change', 'delete', 'view'),
'ordering': ('pk',),
'permissions': [('use_instancegroup', 'Can use instance group in a preference list of a resource')],
},
),
migrations.AlterModelOptions(
name='workflowjobnode',
options={'ordering': ('pk',)},
),
migrations.AlterModelOptions(
name='workflowjobtemplatenode',
options={'ordering': ('pk',)},
),
]

View File

@@ -386,6 +386,7 @@ class gce(PluginFileInjector):
# auth related items # auth related items
ret['auth_kind'] = "serviceaccount" ret['auth_kind'] = "serviceaccount"
filters = []
# TODO: implement gce group_by options # TODO: implement gce group_by options
# gce never processed the group_by field, if it had, we would selectively # gce never processed the group_by field, if it had, we would selectively
# apply those options here, but it did not, so all groups are added here # apply those options here, but it did not, so all groups are added here
@@ -419,6 +420,8 @@ class gce(PluginFileInjector):
if keyed_groups: if keyed_groups:
ret['keyed_groups'] = keyed_groups ret['keyed_groups'] = keyed_groups
if filters:
ret['filters'] = filters
if compose_dict: if compose_dict:
ret['compose'] = compose_dict ret['compose'] = compose_dict
if inventory_source.source_regions and 'all' not in inventory_source.source_regions: if inventory_source.source_regions and 'all' not in inventory_source.source_regions:

View File

@@ -315,11 +315,12 @@ class PrimordialModel(HasEditsMixin, CreatedModifiedModel):
) )
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(PrimordialModel, self).__init__(*args, **kwargs) r = super(PrimordialModel, self).__init__(*args, **kwargs)
if self.pk: if self.pk:
self._prior_values_store = self._get_fields_snapshot() self._prior_values_store = self._get_fields_snapshot()
else: else:
self._prior_values_store = {} self._prior_values_store = {}
return r
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
update_fields = kwargs.get('update_fields', []) update_fields = kwargs.get('update_fields', [])

View File

@@ -28,7 +28,6 @@ from rest_framework.serializers import ValidationError as DRFValidationError
from ansible_base.lib.utils.db import advisory_lock from ansible_base.lib.utils.db import advisory_lock
# AWX # AWX
from awx.main.constants import OIDC_CREDENTIAL_TYPE_NAMESPACES
from awx.api.versioning import reverse from awx.api.versioning import reverse
from awx.main.fields import ( from awx.main.fields import (
ImplicitRoleField, ImplicitRoleField,
@@ -243,29 +242,6 @@ class Credential(PasswordFieldsModel, CommonModelNameNotUnique, ResourceMixin):
needed.append('vault_password') needed.append('vault_password')
return needed return needed
@functools.cached_property
def context(self):
"""
Property for storing runtime context during credential resolution.
The context is a dict keyed by CredentialInputSource PK, where each value
is a dict of runtime fields for that input source. Example::
{
<input_source_pk>: {
"workload_identity_token": "<jwt_token>"
},
<another_input_source_pk>: {
"workload_identity_token": "<different_jwt_token>"
},
}
This structure allows each input source to have its own set of runtime
values, avoiding conflicts when a credential has multiple input sources
with different configurations (e.g., different JWT audiences).
"""
return {}
@cached_property @cached_property
def dynamic_input_fields(self): def dynamic_input_fields(self):
# if the credential is not yet saved we can't access the input_sources # if the credential is not yet saved we can't access the input_sources
@@ -391,7 +367,7 @@ class Credential(PasswordFieldsModel, CommonModelNameNotUnique, ResourceMixin):
def _get_dynamic_input(self, field_name): def _get_dynamic_input(self, field_name):
for input_source in self.input_sources.all(): for input_source in self.input_sources.all():
if input_source.input_field_name == field_name: if input_source.input_field_name == field_name:
return input_source.get_input_value(context=self.context) return input_source.get_input_value()
else: else:
raise ValueError('{} is not a dynamic input field'.format(field_name)) raise ValueError('{} is not a dynamic input field'.format(field_name))
@@ -459,15 +435,13 @@ class CredentialType(CommonModelNameNotUnique):
def from_db(cls, db, field_names, values): def from_db(cls, db, field_names, values):
instance = super(CredentialType, cls).from_db(db, field_names, values) instance = super(CredentialType, cls).from_db(db, field_names, values)
if instance.managed and instance.namespace and instance.kind != "external": if instance.managed and instance.namespace and instance.kind != "external":
native = ManagedCredentialType.registry.get(instance.namespace) native = ManagedCredentialType.registry[instance.namespace]
if native: instance.inputs = native.inputs
instance.inputs = native.inputs instance.injectors = native.injectors
instance.injectors = native.injectors instance.custom_injectors = getattr(native, 'custom_injectors', None)
instance.custom_injectors = getattr(native, 'custom_injectors', None)
elif instance.namespace and instance.kind == "external": elif instance.namespace and instance.kind == "external":
native = ManagedCredentialType.registry.get(instance.namespace) native = ManagedCredentialType.registry[instance.namespace]
if native: instance.inputs = native.inputs
instance.inputs = native.inputs
return instance return instance
@@ -648,15 +622,7 @@ class CredentialInputSource(PrimordialModel):
raise ValidationError(_('Input field must be defined on target credential (options are {}).'.format(', '.join(sorted(defined_fields))))) raise ValidationError(_('Input field must be defined on target credential (options are {}).'.format(', '.join(sorted(defined_fields)))))
return self.input_field_name return self.input_field_name
def get_input_value(self, context: dict | None = None): def get_input_value(self):
"""
Retrieve the value from the external credential backend.
Args:
context: Optional runtime context dict passed from the target credential.
"""
if context is None:
context = {}
backend = self.source_credential.credential_type.plugin.backend backend = self.source_credential.credential_type.plugin.backend
backend_kwargs = {} backend_kwargs = {}
for field_name, value in self.source_credential.inputs.items(): for field_name, value in self.source_credential.inputs.items():
@@ -667,17 +633,6 @@ class CredentialInputSource(PrimordialModel):
backend_kwargs.update(self.metadata) backend_kwargs.update(self.metadata)
# Resolve internal fields from the per-input-source context.
# The context dict is keyed by input source PK, e.g.:
# {42: {"workload_identity_token": "eyJ..."}, 43: {"workload_identity_token": "eyX..."}}
# This allows each input source to carry its own runtime values.
input_source_context = context.get(self.pk, {})
for field in self.source_credential.credential_type.inputs.get('fields', []):
if field.get('internal'):
value = input_source_context.get(field['id'])
if value is not None:
backend_kwargs[field['id']] = value
with set_environ(**settings.AWX_TASK_ENV): with set_environ(**settings.AWX_TASK_ENV):
return backend(**backend_kwargs) return backend(**backend_kwargs)
@@ -686,20 +641,13 @@ class CredentialInputSource(PrimordialModel):
return reverse(view_name, kwargs={'pk': self.pk}, request=request) return reverse(view_name, kwargs={'pk': self.pk}, request=request)
def _is_oidc_namespace_disabled(ns):
"""Check if a credential namespace should be skipped based on the OIDC feature flag."""
return ns in OIDC_CREDENTIAL_TYPE_NAMESPACES and not getattr(settings, 'FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED', False)
def load_credentials(): def load_credentials():
awx_entry_points = {ep.name: ep for ep in entry_points(group='awx_plugins.managed_credentials')} awx_entry_points = {ep.name: ep for ep in entry_points(group='awx_plugins.managed_credentials')}
supported_entry_points = {ep.name: ep for ep in entry_points(group='awx_plugins.managed_credentials.supported')} supported_entry_points = {ep.name: ep for ep in entry_points(group='awx_plugins.managed_credentials.supported')}
plugin_entry_points = awx_entry_points if detect_server_product_name() == 'AWX' else {**awx_entry_points, **supported_entry_points} plugin_entry_points = awx_entry_points if detect_server_product_name() == 'AWX' else {**awx_entry_points, **supported_entry_points}
for ns, ep in plugin_entry_points.items(): for ns, ep in plugin_entry_points.items():
if _is_oidc_namespace_disabled(ns):
continue
cred_plugin = ep.load() cred_plugin = ep.load()
if not hasattr(cred_plugin, 'inputs'): if not hasattr(cred_plugin, 'inputs'):
setattr(cred_plugin, 'inputs', {}) setattr(cred_plugin, 'inputs', {})
@@ -718,8 +666,5 @@ def load_credentials():
credential_plugins = {} credential_plugins = {}
for ns, ep in credential_plugins.items(): for ns, ep in credential_plugins.items():
if _is_oidc_namespace_disabled(ns):
continue
plugin = ep.load() plugin = ep.load()
CredentialType.load_plugin(ns, plugin) CredentialType.load_plugin(ns, plugin)

View File

@@ -50,8 +50,9 @@ class HasPolicyEditsMixin(HasEditsMixin):
abstract = True abstract = True
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(BaseModel, self).__init__(*args, **kwargs) r = super(BaseModel, self).__init__(*args, **kwargs)
self._prior_values_store = self._get_fields_snapshot() self._prior_values_store = self._get_fields_snapshot()
return r
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
super(BaseModel, self).save(*args, **kwargs) super(BaseModel, self).save(*args, **kwargs)
@@ -485,7 +486,6 @@ class InstanceGroup(HasPolicyEditsMixin, BaseModel, RelatedJobsMixin, ResourceMi
class Meta: class Meta:
app_label = 'main' app_label = 'main'
ordering = ('pk',)
permissions = [('use_instancegroup', 'Can use instance group in a preference list of a resource')] permissions = [('use_instancegroup', 'Can use instance group in a preference list of a resource')]
# Since this has no direct organization field only superuser can add, so remove add permission # Since this has no direct organization field only superuser can add, so remove add permission
default_permissions = ('change', 'delete', 'view') default_permissions = ('change', 'delete', 'view')

View File

@@ -845,21 +845,6 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana
def get_notification_friendly_name(self): def get_notification_friendly_name(self):
return "Job" return "Job"
def get_source_hosts_for_constructed_inventory(self):
"""Return a QuerySet of the source (input inventory) hosts for a constructed inventory.
Constructed inventory hosts have an instance_id pointing to the real
host in the input inventory. This resolves those references and returns
a proper QuerySet (never a list), suitable for use with finish_fact_cache.
"""
Host = JobHostSummary._meta.get_field('host').related_model
if not self.inventory_id:
return Host.objects.none()
id_field = Host._meta.get_field('id')
return Host.objects.filter(id__in=self.inventory.hosts.exclude(instance_id='').values_list(Cast('instance_id', output_field=id_field))).only(
*HOST_FACTS_FIELDS
)
def get_hosts_for_fact_cache(self): def get_hosts_for_fact_cache(self):
""" """
Builds the queryset to use for writing or finalizing the fact cache Builds the queryset to use for writing or finalizing the fact cache
@@ -867,15 +852,17 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana
For constructed inventories, that means the original (input inventory) hosts For constructed inventories, that means the original (input inventory) hosts
when slicing, that means only returning hosts in that slice when slicing, that means only returning hosts in that slice
""" """
Host = JobHostSummary._meta.get_field('host').related_model
if not self.inventory_id: if not self.inventory_id:
Host = JobHostSummary._meta.get_field('host').related_model
return Host.objects.none() return Host.objects.none()
if self.inventory.kind == 'constructed': if self.inventory.kind == 'constructed':
host_qs = self.get_source_hosts_for_constructed_inventory() id_field = Host._meta.get_field('id')
host_qs = Host.objects.filter(id__in=self.inventory.hosts.exclude(instance_id='').values_list(Cast('instance_id', output_field=id_field)))
else: else:
host_qs = self.inventory.hosts.only(*HOST_FACTS_FIELDS) host_qs = self.inventory.hosts
host_qs = host_qs.only(*HOST_FACTS_FIELDS)
host_qs = self.inventory.get_sliced_hosts(host_qs, self.job_slice_number, self.job_slice_count) host_qs = self.inventory.get_sliced_hosts(host_qs, self.job_slice_number, self.job_slice_count)
return host_qs return host_qs

View File

@@ -188,16 +188,6 @@ class SurveyJobTemplateMixin(models.Model):
runtime_extra_vars.pop(variable_key) runtime_extra_vars.pop(variable_key)
if default is not None: if default is not None:
# do not add variables that contain an empty string, are not required and are not present in extra_vars
# password fields must be skipped, because default values have special behaviour
if (
default == ''
and not survey_element.get('required')
and survey_element.get('type') != 'password'
and variable_key not in runtime_extra_vars
):
continue
decrypted_default = default decrypted_default = default
if survey_element['type'] == "password" and isinstance(decrypted_default, str) and decrypted_default.startswith('$encrypted$'): if survey_element['type'] == "password" and isinstance(decrypted_default, str) and decrypted_default.startswith('$encrypted$'):
decrypted_default = decrypt_value(get_encryption_key('value', pk=None), decrypted_default) decrypted_default = decrypt_value(get_encryption_key('value', pk=None), decrypted_default)

View File

@@ -10,6 +10,7 @@ import json
import logging import logging
import os import os
import re import re
import socket
import subprocess import subprocess
import tempfile import tempfile
from collections import OrderedDict from collections import OrderedDict
@@ -918,7 +919,7 @@ class UnifiedJob(
# If we have a start and finished time, and haven't already calculated # If we have a start and finished time, and haven't already calculated
# out the time that elapsed, do so. # out the time that elapsed, do so.
if self.started and self.finished and self.elapsed == decimal.Decimal(0): if self.started and self.finished and self.elapsed == 0.0:
td = self.finished - self.started td = self.finished - self.started
elapsed = decimal.Decimal(td.total_seconds()) elapsed = decimal.Decimal(td.total_seconds())
self.elapsed = elapsed.quantize(dq) self.elapsed = elapsed.quantize(dq)
@@ -1354,6 +1355,8 @@ class UnifiedJob(
status_data['instance_group_name'] = None status_data['instance_group_name'] = None
elif status in ['successful', 'failed', 'canceled'] and self.finished: elif status in ['successful', 'failed', 'canceled'] and self.finished:
status_data['finished'] = datetime.datetime.strftime(self.finished, "%Y-%m-%dT%H:%M:%S.%fZ") status_data['finished'] = datetime.datetime.strftime(self.finished, "%Y-%m-%dT%H:%M:%S.%fZ")
elif status == 'running':
status_data['started'] = datetime.datetime.strftime(self.finished, "%Y-%m-%dT%H:%M:%S.%fZ")
status_data.update(self.websocket_emit_data()) status_data.update(self.websocket_emit_data())
status_data['group_name'] = 'jobs' status_data['group_name'] = 'jobs'
if getattr(self, 'unified_job_template_id', None): if getattr(self, 'unified_job_template_id', None):
@@ -1485,17 +1488,40 @@ class UnifiedJob(
return 'Previous Task Canceled: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (self.model_to_str(), self.name, self.id) return 'Previous Task Canceled: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (self.model_to_str(), self.name, self.id)
return None return None
def fallback_cancel(self):
if not self.celery_task_id:
self.refresh_from_db(fields=['celery_task_id'])
self.cancel_dispatcher_process()
def cancel_dispatcher_process(self): def cancel_dispatcher_process(self):
"""Returns True if dispatcher running this job acknowledged request and sent SIGTERM""" """Returns True if dispatcher running this job acknowledged request and sent SIGTERM"""
if not self.celery_task_id: if not self.celery_task_id:
return False return False
# Special case for task manager (used during workflow job cancellation)
if not connection.get_autocommit():
try:
ctl = get_control_from_settings()
ctl.control('cancel', data={'uuid': self.celery_task_id})
except Exception:
logger.exception("Error sending cancel command to dispatcher")
return True # task manager itself needs to act under assumption that cancel was received
# Standard case with reply
try: try:
logger.info(f'Sending cancel message to pg_notify channel {self.controller_node} for task {self.celery_task_id}') timeout = 5
ctl = get_control_from_settings(default_publish_channel=self.controller_node)
ctl.control('cancel', data={'uuid': self.celery_task_id}) ctl = get_control_from_settings()
results = ctl.control_with_reply('cancel', data={'uuid': self.celery_task_id}, expected_replies=1, timeout=timeout)
# Check if cancel was successful by checking if we got any results
return bool(results and len(results) > 0)
except socket.timeout:
logger.error(f'could not reach dispatcher on {self.controller_node} within {timeout}s')
except Exception: except Exception:
logger.exception("Error sending cancel command to dispatcher") logger.exception("error encountered when checking task status")
return False # whether confirmation was obtained
def cancel(self, job_explanation=None, is_chain=False): def cancel(self, job_explanation=None, is_chain=False):
if self.can_cancel: if self.can_cancel:
@@ -1518,13 +1544,19 @@ class UnifiedJob(
# the job control process will use the cancel_flag to distinguish a shutdown from a cancel # the job control process will use the cancel_flag to distinguish a shutdown from a cancel
self.save(update_fields=cancel_fields) self.save(update_fields=cancel_fields)
# Be extra sure we have the task id, in case job is transitioning into running right now controller_notified = False
if not self.celery_task_id:
self.refresh_from_db(fields=['celery_task_id', 'controller_node'])
# send pg_notify message to cancel, will not send until transaction completes
if self.celery_task_id: if self.celery_task_id:
self.cancel_dispatcher_process() controller_notified = self.cancel_dispatcher_process()
# If a SIGTERM signal was sent to the control process, and acked by the dispatcher
# then we want to let its own cleanup change status, otherwise change status now
if not controller_notified:
if self.status != 'canceled':
self.status = 'canceled'
self.save(update_fields=['status'])
# Avoid race condition where we have stale model from pending state but job has already started,
# its checking signal but not cancel_flag, so re-send signal after updating cancel fields
self.fallback_cancel()
return self.cancel_flag return self.cancel_flag

View File

@@ -200,7 +200,6 @@ class WorkflowJobTemplateNode(WorkflowNodeBase):
indexes = [ indexes = [
models.Index(fields=['identifier']), models.Index(fields=['identifier']),
] ]
ordering = ('pk',)
def get_absolute_url(self, request=None): def get_absolute_url(self, request=None):
return reverse('api:workflow_job_template_node_detail', kwargs={'pk': self.pk}, request=request) return reverse('api:workflow_job_template_node_detail', kwargs={'pk': self.pk}, request=request)
@@ -287,7 +286,6 @@ class WorkflowJobNode(WorkflowNodeBase):
models.Index(fields=["identifier", "workflow_job"]), models.Index(fields=["identifier", "workflow_job"]),
models.Index(fields=['identifier']), models.Index(fields=['identifier']),
] ]
ordering = ('pk',)
@property @property
def event_processing_finished(self): def event_processing_finished(self):
@@ -787,7 +785,7 @@ class WorkflowJob(UnifiedJob, WorkflowJobOptions, SurveyJobMixin, JobNotificatio
def cancel_dispatcher_process(self): def cancel_dispatcher_process(self):
# WorkflowJobs don't _actually_ run anything in the dispatcher, so # WorkflowJobs don't _actually_ run anything in the dispatcher, so
# there's no point in asking the dispatcher if it knows about this task # there's no point in asking the dispatcher if it knows about this task
return return True
class WorkflowApprovalTemplate(UnifiedJobTemplate, RelatedJobsMixin): class WorkflowApprovalTemplate(UnifiedJobTemplate, RelatedJobsMixin):
@@ -918,17 +916,6 @@ class WorkflowApproval(UnifiedJob, JobNotificationMixin):
ScheduleWorkflowManager().schedule() ScheduleWorkflowManager().schedule()
return reverse('api:workflow_approval_deny', kwargs={'pk': self.pk}, request=request) return reverse('api:workflow_approval_deny', kwargs={'pk': self.pk}, request=request)
def cancel(self, job_explanation=None, is_chain=False):
# WorkflowApprovals have no dispatcher process (they wait for human
# input) and are excluded from TaskManager processing, so the base
# cancel() would only set cancel_flag without ever transitioning the
# status. We call super() for the flag, then transition directly.
has_already_canceled = bool(self.status == 'canceled')
super().cancel(job_explanation=job_explanation, is_chain=is_chain)
if self.status != 'canceled' and not has_already_canceled:
self.status = 'canceled'
self.save(update_fields=['status'])
def signal_start(self, **kwargs): def signal_start(self, **kwargs):
can_start = super(WorkflowApproval, self).signal_start(**kwargs) can_start = super(WorkflowApproval, self).signal_start(**kwargs)
self.started = self.created self.started = self.created

View File

@@ -76,12 +76,10 @@ class GrafanaBackend(AWXBaseEmailBackend, CustomNotificationBase):
grafana_headers = {} grafana_headers = {}
if 'started' in m.body: if 'started' in m.body:
try: try:
epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc) epoch = datetime.datetime.utcfromtimestamp(0)
grafana_data['time'] = grafana_data['timeEnd'] = int( grafana_data['time'] = grafana_data['timeEnd'] = int((dp.parse(m.body['started']).replace(tzinfo=None) - epoch).total_seconds() * 1000)
(dp.parse(m.body['started']).replace(tzinfo=datetime.timezone.utc) - epoch).total_seconds() * 1000
)
if m.body.get('finished'): if m.body.get('finished'):
grafana_data['timeEnd'] = int((dp.parse(m.body['finished']).replace(tzinfo=datetime.timezone.utc) - epoch).total_seconds() * 1000) grafana_data['timeEnd'] = int((dp.parse(m.body['finished']).replace(tzinfo=None) - epoch).total_seconds() * 1000)
except ValueError: except ValueError:
logger.error(smart_str(_("Error converting time {} or timeEnd {} to int.").format(m.body['started'], m.body['finished']))) logger.error(smart_str(_("Error converting time {} or timeEnd {} to int.").format(m.body['started'], m.body['finished'])))
if not self.fail_silently: if not self.fail_silently:

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2016 Ansible, Inc. # Copyright (c) 2016 Ansible, Inc.
# All Rights Reserved. # All Rights Reserved.
import base64
import json import json
import logging import logging
import requests import requests
@@ -85,25 +84,20 @@ class WebhookBackend(AWXBaseEmailBackend, CustomNotificationBase):
if resp.status_code not in [301, 307]: if resp.status_code not in [301, 307]:
break break
# convert the url to a base64 encoded string for safe logging
url_log_safe = base64.b64encode(url.encode('UTF-8'))
# get the next URL to try
url_next = resp.headers.get("Location", None)
url_next_log_safe = base64.b64encode(url_next.encode('UTF-8')) if url_next else b'None'
# we've hit a redirect. extract the redirect URL out of the first response header and try again # we've hit a redirect. extract the redirect URL out of the first response header and try again
logger.warning(f"Received a {resp.status_code} from {url_log_safe}, trying to reach redirect url {url_next_log_safe}; attempt #{retries+1}") logger.warning(
f"Received a {resp.status_code} from {url}, trying to reach redirect url {resp.headers.get('Location', None)}; attempt #{retries+1}"
)
# take the first redirect URL in the response header and try that # take the first redirect URL in the response header and try that
url = url_next url = resp.headers.get("Location", None)
if url is None: if url is None:
err = f"Webhook notification received redirect to a blank URL from {url_log_safe}. Response headers={resp.headers}" err = f"Webhook notification received redirect to a blank URL from {url}. Response headers={resp.headers}"
break break
else: else:
# no break condition in the loop encountered; therefore we have hit the maximum number of retries # no break condition in the loop encountered; therefore we have hit the maximum number of retries
err = f"Webhook notification max number of retries [{self.MAX_RETRIES}] exceeded. Failed to send webhook notification to {url_log_safe}" err = f"Webhook notification max number of retries [{self.MAX_RETRIES}] exceeded. Failed to send webhook notification to {url}"
if resp.status_code >= 400: if resp.status_code >= 400:
err = f"Error sending webhook notification: {resp.status_code}" err = f"Error sending webhook notification: {resp.status_code}"

View File

@@ -19,8 +19,13 @@ class ActivityStreamRegistrar(object):
pre_delete.connect(activity_stream_delete, sender=model, dispatch_uid=str(self.__class__) + str(model) + "_delete") pre_delete.connect(activity_stream_delete, sender=model, dispatch_uid=str(self.__class__) + str(model) + "_delete")
for m2mfield in model._meta.many_to_many: for m2mfield in model._meta.many_to_many:
m2m_attr = getattr(model, m2mfield.name) try:
m2m_changed.connect(activity_stream_associate, sender=m2m_attr.through, dispatch_uid=str(self.__class__) + str(m2m_attr.through) + "_associate") m2m_attr = getattr(model, m2mfield.name)
m2m_changed.connect(
activity_stream_associate, sender=m2m_attr.through, dispatch_uid=str(self.__class__) + str(m2m_attr.through) + "_associate"
)
except AttributeError:
pass
def disconnect(self, model): def disconnect(self, model):
if model in self.models: if model in self.models:

View File

@@ -48,6 +48,11 @@ class SimpleDAG(object):
''' '''
self.node_to_edges_by_label = dict() self.node_to_edges_by_label = dict()
def __contains__(self, obj):
if self.node['node_object'] in self.node_obj_to_node_index:
return True
return False
def __len__(self): def __len__(self):
return len(self.nodes) return len(self.nodes)

View File

@@ -122,11 +122,8 @@ class WorkflowDAG(SimpleDAG):
if not job: if not job:
continue continue
elif job.can_cancel: elif job.can_cancel:
cancel_finished = False
job.cancel() job.cancel()
# If the job is not yet in a terminal state after .cancel(),
# the TaskManager still needs to process it.
if job.status not in ('successful', 'failed', 'canceled', 'error'):
cancel_finished = False
return cancel_finished return cancel_finished
def is_workflow_done(self): def is_workflow_done(self):

View File

@@ -196,10 +196,6 @@ class WorkflowManager(TaskBase):
workflow_job.start_args = '' # blank field to remove encrypted passwords workflow_job.start_args = '' # blank field to remove encrypted passwords
workflow_job.save(update_fields=['status', 'start_args']) workflow_job.save(update_fields=['status', 'start_args'])
status_changed = True status_changed = True
else:
# Speed-up: schedule the task manager so it can process the
# canceled pending jobs without waiting for the next cycle.
ScheduleTaskManager().schedule()
else: else:
dnr_nodes = dag.mark_dnr_nodes() dnr_nodes = dag.mark_dnr_nodes()
WorkflowJobNode.objects.bulk_update(dnr_nodes, ['do_not_run']) WorkflowJobNode.objects.bulk_update(dnr_nodes, ['do_not_run'])
@@ -447,29 +443,17 @@ class TaskManager(TaskBase):
self.controlplane_ig = self.tm_models.instance_groups.controlplane_ig self.controlplane_ig = self.tm_models.instance_groups.controlplane_ig
def process_job_dep_failures(self, task): def process_job_dep_failures(self, task):
"""If job depends on a job that has failed or been canceled, mark as failed. """If job depends on a job that has failed, mark as failed and handle misc stuff."""
Returns True if a dep failure was found, False otherwise.
"""
for dep in task.dependent_jobs.all(): for dep in task.dependent_jobs.all():
# if we detect a failed, error, or canceled dependency, go ahead and fail this task. # if we detect a failed or error dependency, go ahead and fail this task.
if dep.status in ("error", "failed", "canceled"): if dep.status in ("error", "failed"):
task.status = 'failed' task.status = 'failed'
if dep.status == 'canceled': logger.warning(f'Previous task failed task: {task.id} dep: {dep.id} task manager')
logger.warning(f'Previous task canceled, failing task: {task.id} dep: {dep.id} task manager') task.job_explanation = 'Previous Task Failed: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (
task.job_explanation = 'Previous Task Canceled: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % ( get_type_for_model(type(dep)),
get_type_for_model(type(dep)), dep.name,
dep.name, dep.id,
dep.id, )
)
ScheduleWorkflowManager().schedule() # speedup for dependency chains in workflow, on workflow cancel
else:
logger.warning(f'Previous task failed, failing task: {task.id} dep: {dep.id} task manager')
task.job_explanation = 'Previous Task Failed: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (
get_type_for_model(type(dep)),
dep.name,
dep.id,
)
task.save(update_fields=['status', 'job_explanation']) task.save(update_fields=['status', 'job_explanation'])
task.websocket_emit_status('failed') task.websocket_emit_status('failed')
self.pre_start_failed.append(task.id) self.pre_start_failed.append(task.id)
@@ -561,17 +545,8 @@ class TaskManager(TaskBase):
logger.warning("Task manager has reached time out while processing pending jobs, exiting loop early") logger.warning("Task manager has reached time out while processing pending jobs, exiting loop early")
break break
if task.cancel_flag: has_failed = self.process_job_dep_failures(task)
logger.debug(f"Canceling pending task {task.log_format} because cancel_flag is set") if has_failed:
task.status = 'canceled'
task.job_explanation = gettext_noop("This job was canceled before it started.")
task.save(update_fields=['status', 'job_explanation'])
task.websocket_emit_status('canceled')
self.pre_start_failed.append(task.id)
ScheduleWorkflowManager().schedule()
continue
if self.process_job_dep_failures(task):
continue continue
blocked_by = self.job_blocked_by(task) blocked_by = self.job_blocked_by(task)

View File

@@ -277,6 +277,7 @@ class RunnerCallback:
def artifacts_handler(self, artifact_dir): def artifacts_handler(self, artifact_dir):
success, query_file_contents = try_load_query_file(artifact_dir) success, query_file_contents = try_load_query_file(artifact_dir)
if success: if success:
self.delay_update(event_queries_processed=False)
collections_info = collect_queries(query_file_contents) collections_info = collect_queries(query_file_contents)
for collection, data in collections_info.items(): for collection, data in collections_info.items():
version = data['version'] version = data['version']
@@ -300,24 +301,6 @@ class RunnerCallback:
else: else:
logger.warning(f'The file {COLLECTION_FILENAME} unexpectedly did not contain ansible_version') logger.warning(f'The file {COLLECTION_FILENAME} unexpectedly did not contain ansible_version')
# Write event_queries_processed and installed_collections directly
# to the DB instead of using delay_update. delay_update defers
# writes until the final job status save, but
# events_processed_hook (called from both the task runner after
# the final save and the callback receiver after the wrapup
# event) needs event_queries_processed=False visible in the DB
# to dispatch save_indirect_host_entries. The field defaults to
# True, so without a direct write the hook would see True and
# skip the dispatch. installed_collections is also written
# directly so it is available if the callback receiver
# dispatches before the final save.
from awx.main.models import Job
db_updates = {'event_queries_processed': False}
if 'installed_collections' in query_file_contents:
db_updates['installed_collections'] = query_file_contents['installed_collections']
Job.objects.filter(id=self.instance.id).update(**db_updates)
self.artifacts_processed = True self.artifacts_processed = True

View File

@@ -25,8 +25,7 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
log_data = log_data or {} log_data = log_data or {}
log_data['inventory_id'] = inventory_id log_data['inventory_id'] = inventory_id
log_data['written_ct'] = 0 log_data['written_ct'] = 0
# Dict mapping host name -> bool (True if a fact file was written) hosts_cached = []
hosts_cached = {}
# Create the fact_cache directory inside artifacts_dir # Create the fact_cache directory inside artifacts_dir
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache') fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
@@ -38,14 +37,13 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
last_write_time = None last_write_time = None
for host in hosts: for host in hosts:
hosts_cached.append(host.name)
if not host.ansible_facts_modified or (timeout and host.ansible_facts_modified < now() - datetime.timedelta(seconds=timeout)): if not host.ansible_facts_modified or (timeout and host.ansible_facts_modified < now() - datetime.timedelta(seconds=timeout)):
hosts_cached[host.name] = False
continue # facts are expired - do not write them continue # facts are expired - do not write them
filepath = os.path.join(fact_cache_dir, host.name) filepath = os.path.join(fact_cache_dir, host.name)
if not os.path.realpath(filepath).startswith(fact_cache_dir): if not os.path.realpath(filepath).startswith(fact_cache_dir):
logger.error(f'facts for host {smart_str(host.name)} could not be cached') logger.error(f'facts for host {smart_str(host.name)} could not be cached')
hosts_cached[host.name] = False
continue continue
try: try:
@@ -53,18 +51,9 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
os.chmod(f.name, 0o600) os.chmod(f.name, 0o600)
json.dump(host.ansible_facts, f) json.dump(host.ansible_facts, f)
log_data['written_ct'] += 1 log_data['written_ct'] += 1
# Backdate the file by 2 seconds so finish_fact_cache can reliably last_write_time = os.path.getmtime(filepath)
# distinguish these reference files from files updated by ansible.
# This guarantees fact file mtime < summary file mtime even with
# zipfile's 2-second timestamp rounding during artifact transfer.
mtime = os.path.getmtime(filepath)
backdated = mtime - 2
os.utime(filepath, (backdated, backdated))
last_write_time = backdated
hosts_cached[host.name] = True
except IOError: except IOError:
logger.error(f'facts for host {smart_str(host.name)} could not be cached') logger.error(f'facts for host {smart_str(host.name)} could not be cached')
hosts_cached[host.name] = False
continue continue
# Write summary file directly to the artifacts_dir # Write summary file directly to the artifacts_dir
@@ -73,6 +62,7 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
summary_data = { summary_data = {
'last_write_time': last_write_time, 'last_write_time': last_write_time,
'hosts_cached': hosts_cached, 'hosts_cached': hosts_cached,
'written_ct': log_data['written_ct'],
} }
with open(summary_file, 'w', encoding='utf-8') as f: with open(summary_file, 'w', encoding='utf-8') as f:
json.dump(summary_data, f, indent=2) json.dump(summary_data, f, indent=2)
@@ -84,7 +74,7 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
msg='Inventory {inventory_id} host facts: updated {updated_ct}, cleared {cleared_ct}, unchanged {unmodified_ct}, took {delta:.3f} s', msg='Inventory {inventory_id} host facts: updated {updated_ct}, cleared {cleared_ct}, unchanged {unmodified_ct}, took {delta:.3f} s',
add_log_data=True, add_log_data=True,
) )
def finish_fact_cache(host_qs, artifacts_dir, job_id=None, inventory_id=None, job_created=None, log_data=None): def finish_fact_cache(artifacts_dir, job_id=None, inventory_id=None, log_data=None):
log_data = log_data or {} log_data = log_data or {}
log_data['inventory_id'] = inventory_id log_data['inventory_id'] = inventory_id
log_data['updated_ct'] = 0 log_data['updated_ct'] = 0
@@ -104,9 +94,8 @@ def finish_fact_cache(host_qs, artifacts_dir, job_id=None, inventory_id=None, jo
logger.error(f'Error reading summary file at {summary_path}: {e}') logger.error(f'Error reading summary file at {summary_path}: {e}')
return return
hosts_cached_map = summary.get('hosts_cached', {}) host_names = summary.get('hosts_cached', [])
host_names = list(hosts_cached_map.keys()) hosts_cached = Host.objects.filter(name__in=host_names).order_by('id').iterator()
hosts_cached = host_qs.filter(name__in=host_names).order_by('id').iterator()
# Path where individual fact files were written # Path where individual fact files were written
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache') fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
hosts_to_update = [] hosts_to_update = []
@@ -147,35 +136,16 @@ def finish_fact_cache(host_qs, artifacts_dir, job_id=None, inventory_id=None, jo
else: else:
log_data['unmodified_ct'] += 1 log_data['unmodified_ct'] += 1
else: else:
# File is missing. Only interpret this as "ansible cleared facts" if
# start_fact_cache actually wrote a file for this host (i.e. the host
# had valid, non-expired facts before the job ran). If no file was
# ever written, the missing file is expected and not a clear signal.
if not hosts_cached_map.get(host.name):
log_data['unmodified_ct'] += 1
continue
# if the file goes missing, ansible removed it (likely via clear_facts) # if the file goes missing, ansible removed it (likely via clear_facts)
# if the file goes missing, but the host has not started facts, then we should not clear the facts # if the file goes missing, but the host has not started facts, then we should not clear the facts
if job_created and host.ansible_facts_modified and host.ansible_facts_modified > job_created: host.ansible_facts = {}
logger.warning( host.ansible_facts_modified = now()
f'Skipping fact clear for host {smart_str(host.name)} in job {job_id} ' hosts_to_update.append(host)
f'inventory {inventory_id}: host ansible_facts_modified ' logger.info(f'Facts cleared for inventory {smart_str(host.inventory.name)} host {smart_str(host.name)}')
f'({host.ansible_facts_modified.isoformat()}) is after this job\'s ' log_data['cleared_ct'] += 1
f'created time ({job_created.isoformat()}). '
f'A concurrent job likely updated this host\'s facts while this job was running.'
)
log_data['unmodified_ct'] += 1
else:
host.ansible_facts = {}
host.ansible_facts_modified = now()
hosts_to_update.append(host)
logger.info(f'Facts cleared for inventory {smart_str(host.inventory.name)} host {smart_str(host.name)}')
log_data['cleared_ct'] += 1
if len(hosts_to_update) >= 100: if len(hosts_to_update) >= 100:
bulk_update_sorted_by_id(Host, hosts_to_update, fields=['ansible_facts', 'ansible_facts_modified']) bulk_update_sorted_by_id(Host, hosts_to_update, fields=['ansible_facts', 'ansible_facts_modified'])
hosts_to_update = [] hosts_to_update = []
bulk_update_sorted_by_id(Host, hosts_to_update, fields=['ansible_facts', 'ansible_facts_modified']) bulk_update_sorted_by_id(Host, hosts_to_update, fields=['ansible_facts', 'ansible_facts_modified'])
logger.debug(f'Updated {log_data["updated_ct"]} host facts for inventory {inventory_id} in job {job_id}')

View File

@@ -17,6 +17,7 @@ import urllib.parse as urlparse
# Django # Django
from django.conf import settings from django.conf import settings
from django.db import transaction
# Shared code for the AWX platform # Shared code for the AWX platform
from awx_plugins.interfaces._temporary_private_container_api import CONTAINER_ROOT, get_incontainer_path from awx_plugins.interfaces._temporary_private_container_api import CONTAINER_ROOT, get_incontainer_path
@@ -83,7 +84,6 @@ from awx.main.utils.common import (
create_partition, create_partition,
ScheduleWorkflowManager, ScheduleWorkflowManager,
ScheduleTaskManager, ScheduleTaskManager,
getattr_dne,
) )
from awx.conf.license import get_license from awx.conf.license import get_license
from awx.main.utils.handlers import SpecialInventoryHandler from awx.main.utils.handlers import SpecialInventoryHandler
@@ -92,92 +92,9 @@ from awx.main.utils.update_model import update_model
# Django flags # Django flags
from flags.state import flag_enabled from flags.state import flag_enabled
# Workload Identity
from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope
from ansible_base.resource_registry.workload_identity_client import get_workload_identity_client
logger = logging.getLogger('awx.main.tasks.jobs') logger = logging.getLogger('awx.main.tasks.jobs')
def populate_claims_for_workload(unified_job) -> dict:
"""
Extract JWT claims from a Controller workload for the aap_controller_automation_job scope.
"""
claims = {
AutomationControllerJobScope.CLAIM_JOB_ID: unified_job.id,
AutomationControllerJobScope.CLAIM_JOB_NAME: unified_job.name,
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: unified_job.launch_type,
}
# Related objects in the UnifiedJob model, applies to all job types
# null cases are omitted because of OIDC
if organization := getattr_dne(unified_job, 'organization'):
claims[AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME] = organization.name
claims[AutomationControllerJobScope.CLAIM_ORGANIZATION_ID] = organization.id
if ujt := getattr_dne(unified_job, 'unified_job_template'):
claims[AutomationControllerJobScope.CLAIM_UNIFIED_JOB_TEMPLATE_NAME] = ujt.name
claims[AutomationControllerJobScope.CLAIM_UNIFIED_JOB_TEMPLATE_ID] = ujt.id
if instance_group := getattr_dne(unified_job, 'instance_group'):
claims[AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_NAME] = instance_group.name
claims[AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_ID] = instance_group.id
# Related objects on concrete models, may not be valid for type of unified_job
if inventory := getattr_dne(unified_job, 'inventory', None):
claims[AutomationControllerJobScope.CLAIM_INVENTORY_NAME] = inventory.name
claims[AutomationControllerJobScope.CLAIM_INVENTORY_ID] = inventory.id
if execution_environment := getattr_dne(unified_job, 'execution_environment', None):
claims[AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_NAME] = execution_environment.name
claims[AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_ID] = execution_environment.id
if project := getattr_dne(unified_job, 'project', None):
claims[AutomationControllerJobScope.CLAIM_PROJECT_NAME] = project.name
claims[AutomationControllerJobScope.CLAIM_PROJECT_ID] = project.id
if jt := getattr_dne(unified_job, 'job_template', None):
claims[AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_NAME] = jt.name
claims[AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_ID] = jt.id
# Only valid for job templates
if hasattr(unified_job, 'playbook'):
claims[AutomationControllerJobScope.CLAIM_PLAYBOOK_NAME] = unified_job.playbook
# Not valid for inventory updates and system jobs
if hasattr(unified_job, 'job_type'):
claims[AutomationControllerJobScope.CLAIM_JOB_TYPE] = unified_job.job_type
launched_by: dict = unified_job.launched_by
if 'name' in launched_by:
claims[AutomationControllerJobScope.CLAIM_LAUNCHED_BY_NAME] = launched_by['name']
if 'id' in launched_by:
claims[AutomationControllerJobScope.CLAIM_LAUNCHED_BY_ID] = launched_by['id']
return claims
def retrieve_workload_identity_jwt(
unified_job: UnifiedJob,
audience: str,
scope: str,
workload_ttl_seconds: int | None = None,
) -> str:
"""Retrieve JWT token from workload claims.
Raises:
RuntimeError: if the workload identity client is not configured.
"""
client = get_workload_identity_client()
if client is None:
raise RuntimeError("Workload identity client is not configured")
claims = populate_claims_for_workload(unified_job)
kwargs = {"claims": claims, "scope": scope, "audience": audience}
if workload_ttl_seconds:
kwargs["workload_ttl_seconds"] = workload_ttl_seconds
return client.request_workload_jwt(**kwargs).jwt
def with_path_cleanup(f): def with_path_cleanup(f):
@functools.wraps(f) @functools.wraps(f)
def _wrapped(self, *args, **kwargs): def _wrapped(self, *args, **kwargs):
@@ -204,7 +121,6 @@ def dispatch_waiting_jobs(binder):
if not kwargs: if not kwargs:
kwargs = {} kwargs = {}
binder.control('run', data={'task': serialize_task(uj._get_task_class()), 'args': [uj.id], 'kwargs': kwargs, 'uuid': uj.celery_task_id}) binder.control('run', data={'task': serialize_task(uj._get_task_class()), 'args': [uj.id], 'kwargs': kwargs, 'uuid': uj.celery_task_id})
UnifiedJob.objects.filter(pk=uj.pk, status='waiting').update(status='running', start_args='')
class BaseTask(object): class BaseTask(object):
@@ -219,60 +135,6 @@ class BaseTask(object):
self.update_attempts = int(getattr(settings, 'DISPATCHER_DB_DOWNTOWN_TOLLERANCE', settings.DISPATCHER_DB_DOWNTIME_TOLERANCE) / 5) self.update_attempts = int(getattr(settings, 'DISPATCHER_DB_DOWNTOWN_TOLLERANCE', settings.DISPATCHER_DB_DOWNTIME_TOLERANCE) / 5)
self.runner_callback = self.callback_class(model=self.model) self.runner_callback = self.callback_class(model=self.model)
@functools.cached_property
def _credentials(self):
"""
Credentials for the task execution.
Fetches credentials once using build_credentials_list() and stores
them for the duration of the task to avoid redundant database queries.
"""
credentials_list = self.build_credentials_list(self.instance)
# Convert to list to prevent re-evaluation of QuerySet
return list(credentials_list)
def populate_workload_identity_tokens(self):
"""
Populate credentials with workload identity tokens.
Sets the context on Credential objects that have input sources
using compatible external credential types.
"""
credential_input_sources = (
(credential.context, src)
for credential in self._credentials
for src in credential.input_sources.all()
if any(
field.get('id') == 'workload_identity_token' and field.get('internal')
for field in src.source_credential.credential_type.inputs.get('fields', [])
)
)
for credential_ctx, input_src in credential_input_sources:
if flag_enabled("FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED"):
effective_timeout = self.get_instance_timeout(self.instance)
workload_ttl = effective_timeout if effective_timeout else None
try:
jwt = retrieve_workload_identity_jwt(
self.instance,
audience=input_src.source_credential.get_input('jwt_aud'),
scope=AutomationControllerJobScope.name,
workload_ttl_seconds=workload_ttl,
)
# Store token keyed by input source PK, since a credential can have
# multiple input sources (one per field), each potentially with a different audience
credential_ctx[input_src.pk] = {"workload_identity_token": jwt}
except Exception as e:
self.instance.job_explanation = (
f'Could not generate workload identity token for credential {input_src.source_credential.name} used in this job. Error:\n{e}'
)
self.instance.status = 'error'
self.instance.save()
else:
self.instance.job_explanation = (
f'Flag FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED is not enabled, required for credential {input_src.source_credential.name} used in this job.'
)
self.instance.status = 'error'
self.instance.save()
def update_model(self, pk, _attempt=0, **updates): def update_model(self, pk, _attempt=0, **updates):
return update_model(self.model, pk, _attempt=0, _max_attempts=self.update_attempts, **updates) return update_model(self.model, pk, _attempt=0, _max_attempts=self.update_attempts, **updates)
@@ -424,19 +286,6 @@ class BaseTask(object):
private_data_files['credentials'][credential] = self.write_private_data_file(private_data_dir, None, data, sub_dir='env') private_data_files['credentials'][credential] = self.write_private_data_file(private_data_dir, None, data, sub_dir='env')
for credential, data in private_data.get('certificates', {}).items(): for credential, data in private_data.get('certificates', {}).items():
self.write_private_data_file(private_data_dir, 'ssh_key_data-cert.pub', data, sub_dir=os.path.join('artifacts', str(self.instance.id))) self.write_private_data_file(private_data_dir, 'ssh_key_data-cert.pub', data, sub_dir=os.path.join('artifacts', str(self.instance.id)))
# Copy vendor collections to private_data_dir for indirect node counting
# This makes external query files available to the callback plugin in EEs
if flag_enabled("FEATURE_INDIRECT_NODE_COUNTING_ENABLED"):
vendor_src = '/var/lib/awx/vendor_collections'
vendor_dest = os.path.join(private_data_dir, 'vendor_collections')
if os.path.exists(vendor_src):
try:
shutil.copytree(vendor_src, vendor_dest)
logger.debug(f"Copied vendor collections from {vendor_src} to {vendor_dest}")
except Exception as e:
logger.warning(f"Failed to copy vendor collections: {e}")
return private_data_files, ssh_key_data return private_data_files, ssh_key_data
def build_passwords(self, instance, runtime_passwords): def build_passwords(self, instance, runtime_passwords):
@@ -510,7 +359,6 @@ class BaseTask(object):
return [] return []
def get_instance_timeout(self, instance): def get_instance_timeout(self, instance):
"""Return the effective job timeout in seconds."""
global_timeout_setting_name = instance._global_timeout_setting() global_timeout_setting_name = instance._global_timeout_setting()
if global_timeout_setting_name: if global_timeout_setting_name:
global_timeout = getattr(settings, global_timeout_setting_name, 0) global_timeout = getattr(settings, global_timeout_setting_name, 0)
@@ -619,32 +467,48 @@ class BaseTask(object):
def should_use_fact_cache(self): def should_use_fact_cache(self):
return False return False
def transition_status(self, pk: int) -> bool:
"""Atomically transition status to running, if False returned, another process got it"""
with transaction.atomic():
# Explanation of parts for the fetch:
# .values - avoid loading a full object, this is known to lead to deadlocks due to signals
# the signals load other related rows which another process may be locking, and happens in practice
# of=('self',) - keeps FK tables out of the lock list, another way deadlocks can happen
# .get - just load the single job
instance_data = UnifiedJob.objects.select_for_update(of=('self',)).values('status', 'cancel_flag').get(pk=pk)
# If status is not waiting (obtained under lock) then this process does not have clearence to run
if instance_data['status'] == 'waiting':
if instance_data['cancel_flag']:
updated_status = 'canceled'
else:
updated_status = 'running'
# Explanation of the update:
# .filter - again, do not load the full object
# .update - a bulk update on just that one row, avoid loading unintended data
UnifiedJob.objects.filter(pk=pk).update(status=updated_status, start_args='')
elif instance_data['status'] == 'running':
logger.info(f'Job {pk} is being ran by another process, exiting')
return False
return True
@with_path_cleanup @with_path_cleanup
@with_signal_handling @with_signal_handling
def run(self, pk, **kwargs): def run(self, pk, **kwargs):
""" """
Run the job/task and capture its output. Run the job/task and capture its output.
""" """
if not self.instance: # Used to skip fetch for local runs if not self.instance: # Used to skip fetch for local runs
# Load the instance if not self.transition_status(pk):
self.instance = self.update_model(pk) logger.info(f'Job {pk} is being ran by another process, exiting')
return
# status should be "running" from dispatch_waiting_jobs,
# but may still be "waiting" if the worker picked this up before the status update landed.
if self.instance.status == 'waiting':
UnifiedJob.objects.filter(pk=pk).update(status="running", start_args='')
self.instance.refresh_from_db()
# Load the instance
self.instance = self.update_model(pk)
if self.instance.status != 'running': if self.instance.status != 'running':
logger.error(f'Not starting {self.instance.status} task pk={pk} because its status "{self.instance.status}" is not expected') logger.error(f'Not starting {self.instance.status} task pk={pk} because its status "{self.instance.status}" is not expected')
return return
if self.instance.cancel_flag:
self.instance = self.update_model(pk, status='canceled')
self.instance.websocket_emit_status('canceled')
return
self.instance.websocket_emit_status("running") self.instance.websocket_emit_status("running")
status, rc = 'error', None status, rc = 'error', None
self.runner_callback.event_ct = 0 self.runner_callback.event_ct = 0
@@ -683,12 +547,6 @@ class BaseTask(object):
if not os.path.exists(settings.AWX_ISOLATION_BASE_PATH): if not os.path.exists(settings.AWX_ISOLATION_BASE_PATH):
raise RuntimeError('AWX_ISOLATION_BASE_PATH=%s does not exist' % settings.AWX_ISOLATION_BASE_PATH) raise RuntimeError('AWX_ISOLATION_BASE_PATH=%s does not exist' % settings.AWX_ISOLATION_BASE_PATH)
if flag_enabled("FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED"):
logger.info(f'Generating workload identity tokens for {self.instance.log_format}')
self.populate_workload_identity_tokens()
if self.instance.status == 'error':
raise RuntimeError('not starting %s task' % self.instance.status)
# May have to serialize the value # May have to serialize the value
private_data_files, ssh_key_data = self.build_private_data_files(self.instance, private_data_dir) private_data_files, ssh_key_data = self.build_private_data_files(self.instance, private_data_dir)
passwords = self.build_passwords(self.instance, kwargs) passwords = self.build_passwords(self.instance, kwargs)
@@ -706,7 +564,7 @@ class BaseTask(object):
self.runner_callback.job_created = str(self.instance.created) self.runner_callback.job_created = str(self.instance.created)
credentials = self._credentials credentials = self.build_credentials_list(self.instance)
container_root = None container_root = None
if settings.IS_K8S and isinstance(self.instance, ProjectUpdate): if settings.IS_K8S and isinstance(self.instance, ProjectUpdate):
@@ -1001,29 +859,6 @@ class RunJob(SourceControlMixin, BaseTask):
model = Job model = Job
event_model = JobEvent event_model = JobEvent
def _extract_credentials_of_kind(self, kind: str):
return (cred for cred in self._credentials if cred.credential_type.kind == kind)
@property
def _machine_credential(self) -> object:
"""Get machine credential."""
return next(self._extract_credentials_of_kind('ssh'), None)
@property
def _vault_credentials(self) -> list[object]:
"""Get vault credentials."""
return list(self._extract_credentials_of_kind('vault'))
@property
def _network_credentials(self) -> list[object]:
"""Get network credentials."""
return list(self._extract_credentials_of_kind('net'))
@property
def _cloud_credentials(self) -> list[object]:
"""Get cloud credentials."""
return list(self._extract_credentials_of_kind('cloud'))
def build_private_data(self, job, private_data_dir): def build_private_data(self, job, private_data_dir):
""" """
Returns a dict of the form Returns a dict of the form
@@ -1041,7 +876,7 @@ class RunJob(SourceControlMixin, BaseTask):
} }
""" """
private_data = {'credentials': {}} private_data = {'credentials': {}}
for credential in self._credentials: for credential in job.credentials.prefetch_related('input_sources__source_credential').all():
# If we were sent SSH credentials, decrypt them and send them # If we were sent SSH credentials, decrypt them and send them
# back (they will be written to a temporary file). # back (they will be written to a temporary file).
if credential.has_input('ssh_key_data'): if credential.has_input('ssh_key_data'):
@@ -1057,14 +892,14 @@ class RunJob(SourceControlMixin, BaseTask):
and ansible-vault. and ansible-vault.
""" """
passwords = super(RunJob, self).build_passwords(job, runtime_passwords) passwords = super(RunJob, self).build_passwords(job, runtime_passwords)
cred = self._machine_credential cred = job.machine_credential
if cred: if cred:
for field in ('ssh_key_unlock', 'ssh_password', 'become_password', 'vault_password'): for field in ('ssh_key_unlock', 'ssh_password', 'become_password', 'vault_password'):
value = runtime_passwords.get(field, cred.get_input('password' if field == 'ssh_password' else field, default='')) value = runtime_passwords.get(field, cred.get_input('password' if field == 'ssh_password' else field, default=''))
if value not in ('', 'ASK'): if value not in ('', 'ASK'):
passwords[field] = value passwords[field] = value
for cred in self._vault_credentials: for cred in job.vault_credentials:
field = 'vault_password' field = 'vault_password'
vault_id = cred.get_input('vault_id', default=None) vault_id = cred.get_input('vault_id', default=None)
if vault_id: if vault_id:
@@ -1080,7 +915,7 @@ class RunJob(SourceControlMixin, BaseTask):
key unlock over network key unlock. key unlock over network key unlock.
''' '''
if 'ssh_key_unlock' not in passwords: if 'ssh_key_unlock' not in passwords:
for cred in self._network_credentials: for cred in job.network_credentials:
if cred.inputs.get('ssh_key_unlock'): if cred.inputs.get('ssh_key_unlock'):
passwords['ssh_key_unlock'] = runtime_passwords.get('ssh_key_unlock', cred.get_input('ssh_key_unlock', default='')) passwords['ssh_key_unlock'] = runtime_passwords.get('ssh_key_unlock', cred.get_input('ssh_key_unlock', default=''))
break break
@@ -1115,11 +950,11 @@ class RunJob(SourceControlMixin, BaseTask):
# Set environment variables for cloud credentials. # Set environment variables for cloud credentials.
cred_files = private_data_files.get('credentials', {}) cred_files = private_data_files.get('credentials', {})
for cloud_cred in self._cloud_credentials: for cloud_cred in job.cloud_credentials:
if cloud_cred and cloud_cred.credential_type.namespace == 'openstack' and cred_files.get(cloud_cred, ''): if cloud_cred and cloud_cred.credential_type.namespace == 'openstack' and cred_files.get(cloud_cred, ''):
env['OS_CLIENT_CONFIG_FILE'] = get_incontainer_path(cred_files.get(cloud_cred, ''), private_data_dir) env['OS_CLIENT_CONFIG_FILE'] = get_incontainer_path(cred_files.get(cloud_cred, ''), private_data_dir)
for network_cred in self._network_credentials: for network_cred in job.network_credentials:
env['ANSIBLE_NET_USERNAME'] = network_cred.get_input('username', default='') env['ANSIBLE_NET_USERNAME'] = network_cred.get_input('username', default='')
env['ANSIBLE_NET_PASSWORD'] = network_cred.get_input('password', default='') env['ANSIBLE_NET_PASSWORD'] = network_cred.get_input('password', default='')
@@ -1162,11 +997,6 @@ class RunJob(SourceControlMixin, BaseTask):
if 'callbacks_enabled' in config_values: if 'callbacks_enabled' in config_values:
env['ANSIBLE_CALLBACKS_ENABLED'] += ':' + config_values['callbacks_enabled'] env['ANSIBLE_CALLBACKS_ENABLED'] += ':' + config_values['callbacks_enabled']
# Add vendor collections path for external query file discovery
vendor_collections_path = os.path.join(CONTAINER_ROOT, 'vendor_collections')
env['ANSIBLE_COLLECTIONS_PATH'] = f"{vendor_collections_path}:{env['ANSIBLE_COLLECTIONS_PATH']}"
logger.debug(f"ANSIBLE_COLLECTIONS_PATH updated for vendor collections: {env['ANSIBLE_COLLECTIONS_PATH']}")
return env return env
def build_args(self, job, private_data_dir, passwords): def build_args(self, job, private_data_dir, passwords):
@@ -1174,7 +1004,7 @@ class RunJob(SourceControlMixin, BaseTask):
Build command line argument list for running ansible-playbook, Build command line argument list for running ansible-playbook,
optionally using ssh-agent for public/private key authentication. optionally using ssh-agent for public/private key authentication.
""" """
creds = self._machine_credential creds = job.machine_credential
ssh_username, become_username, become_method = '', '', '' ssh_username, become_username, become_method = '', '', ''
if creds: if creds:
@@ -1326,16 +1156,10 @@ class RunJob(SourceControlMixin, BaseTask):
return return
if self.should_use_fact_cache() and self.runner_callback.artifacts_processed: if self.should_use_fact_cache() and self.runner_callback.artifacts_processed:
job.log_lifecycle("finish_job_fact_cache") job.log_lifecycle("finish_job_fact_cache")
if job.inventory.kind == 'constructed':
hosts_qs = job.get_source_hosts_for_constructed_inventory()
else:
hosts_qs = job.inventory.hosts
finish_fact_cache( finish_fact_cache(
hosts_qs,
artifacts_dir=os.path.join(private_data_dir, 'artifacts', str(job.id)), artifacts_dir=os.path.join(private_data_dir, 'artifacts', str(job.id)),
job_id=job.id, job_id=job.id,
inventory_id=job.inventory_id, inventory_id=job.inventory_id,
job_created=job.created,
) )
def final_run_hook(self, job, status, private_data_dir): def final_run_hook(self, job, status, private_data_dir):
@@ -1504,6 +1328,7 @@ class RunProjectUpdate(BaseTask):
'local_path': os.path.basename(project_update.project.local_path), 'local_path': os.path.basename(project_update.project.local_path),
'project_path': project_update.get_project_path(check_if_exists=False), # deprecated 'project_path': project_update.get_project_path(check_if_exists=False), # deprecated
'insights_url': settings.INSIGHTS_URL_BASE, 'insights_url': settings.INSIGHTS_URL_BASE,
'oidc_endpoint': settings.INSIGHTS_OIDC_ENDPOINT,
'awx_license_type': get_license().get('license_type', 'UNLICENSED'), 'awx_license_type': get_license().get('license_type', 'UNLICENSED'),
'awx_version': get_awx_version(), 'awx_version': get_awx_version(),
'scm_url': scm_url, 'scm_url': scm_url,

View File

@@ -393,9 +393,9 @@ def evaluate_policy(instance):
raise PolicyEvaluationError(_('Following certificate settings are missing for OPA_AUTH_TYPE=Certificate: {}').format(cert_settings_missing)) raise PolicyEvaluationError(_('Following certificate settings are missing for OPA_AUTH_TYPE=Certificate: {}').format(cert_settings_missing))
query_paths = [ query_paths = [
('Organization', instance.organization.opa_query_path if instance.organization else None), ('Organization', instance.organization.opa_query_path),
('Inventory', instance.inventory.opa_query_path if instance.inventory else None), ('Inventory', instance.inventory.opa_query_path),
('Job template', instance.job_template.opa_query_path if instance.job_template else None), ('Job template', instance.job_template.opa_query_path),
] ]
violations = dict() violations = dict()
errors = dict() errors = dict()

View File

@@ -69,7 +69,7 @@ def signal_callback():
def with_signal_handling(f): def with_signal_handling(f):
""" """
Change signal handling to make signal_callback return True in event of SIGTERM, SIGINT, or SIGUSR1. Change signal handling to make signal_callback return True in event of SIGTERM or SIGINT.
""" """
@functools.wraps(f) @functools.wraps(f)

View File

@@ -93,10 +93,7 @@ def _run_dispatch_startup_common():
# TODO: Enable this on VM installs # TODO: Enable this on VM installs
if settings.IS_K8S: if settings.IS_K8S:
try: write_receptor_config()
write_receptor_config()
except Exception:
logger.exception("Failed to write receptor config, skipping.")
try: try:
convert_jsonfields() convert_jsonfields()
@@ -760,16 +757,14 @@ def _heartbeat_check_versions(this_inst, instance_list):
def _heartbeat_handle_lost_instances(lost_instances, this_inst): def _heartbeat_handle_lost_instances(lost_instances, this_inst):
"""Handle lost instances by reaping their running jobs and marking them offline.""" """Handle lost instances by reaping their jobs and marking them offline."""
for other_inst in lost_instances: for other_inst in lost_instances:
try: try:
# Any jobs marked as running will be marked as error
explanation = "Job reaped due to instance shutdown" explanation = "Job reaped due to instance shutdown"
reaper.reap(other_inst, job_explanation=explanation) reaper.reap(other_inst, job_explanation=explanation)
# Any jobs that were waiting to be processed by this node will be handed back to task manager reaper.reap_waiting(other_inst, grace_period=0, job_explanation=explanation)
UnifiedJob.objects.filter(status='waiting', controller_node=other_inst.hostname).update(status='pending', controller_node='', execution_node='')
except Exception: except Exception:
logger.exception('failed to re-process jobs for lost instance {}'.format(other_inst.hostname)) logger.exception('failed to reap jobs for {}'.format(other_inst.hostname))
try: try:
if settings.AWX_AUTO_DEPROVISION_INSTANCES and other_inst.node_type == "control": if settings.AWX_AUTO_DEPROVISION_INSTANCES and other_inst.node_type == "control":
deprovision_hostname = other_inst.hostname deprovision_hostname = other_inst.hostname

View File

@@ -1,19 +0,0 @@
---
authors:
- AWX Project Contributors <awx-project@googlegroups.com>
dependencies: {}
description: External query testing collection. No embedded query file. Not for use in production.
documentation: https://github.com/ansible/awx
homepage: https://github.com/ansible/awx
issues: https://github.com/ansible/awx
license:
- GPL-3.0-or-later
name: external
namespace: demo
readme: README.md
repository: https://github.com/ansible/awx
tags:
- demo
- testing
- external_query
version: 1.0.0

View File

@@ -1,78 +0,0 @@
#!/usr/bin/python
# Same licensing as AWX
from __future__ import absolute_import, division, print_function
__metaclass__ = type
DOCUMENTATION = r'''
---
module: example
short_description: Module for specific live tests
version_added: "2.0.0"
description: This module is part of a test collection in local source. Used for external query testing.
options:
host_name:
description: Name to return as the host name.
required: false
type: str
author:
- AWX Live Tests
'''
EXAMPLES = r'''
- name: Test with defaults
demo.external.example:
- name: Test with custom host name
demo.external.example:
host_name: foo_host
'''
RETURN = r'''
direct_host_name:
description: The name of the host, this will be collected with the feature.
type: str
returned: always
sample: 'foo_host'
'''
from ansible.module_utils.basic import AnsibleModule
def run_module():
module_args = dict(
host_name=dict(type='str', required=False, default='foo_host_default'),
)
result = dict(
changed=False,
other_data='sample_string',
)
module = AnsibleModule(argument_spec=module_args, supports_check_mode=True)
if module.check_mode:
module.exit_json(**result)
result['direct_host_name'] = module.params['host_name']
result['nested_host_name'] = {'host_name': module.params['host_name']}
result['name'] = 'vm-foo'
# non-cononical facts
result['device_type'] = 'Fake Host'
module.exit_json(**result)
def main():
run_module()
if __name__ == '__main__':
main()

View File

@@ -1,19 +0,0 @@
---
authors:
- AWX Project Contributors <awx-project@googlegroups.com>
dependencies: {}
description: External query testing collection v1.5.0. No embedded query file. Not for use in production.
documentation: https://github.com/ansible/awx
homepage: https://github.com/ansible/awx
issues: https://github.com/ansible/awx
license:
- GPL-3.0-or-later
name: external
namespace: demo
readme: README.md
repository: https://github.com/ansible/awx
tags:
- demo
- testing
- external_query
version: 1.5.0

View File

@@ -1,78 +0,0 @@
#!/usr/bin/python
# Same licensing as AWX
from __future__ import absolute_import, division, print_function
__metaclass__ = type
DOCUMENTATION = r'''
---
module: example
short_description: Module for specific live tests
version_added: "2.0.0"
description: This module is part of a test collection in local source. Used for external query testing.
options:
host_name:
description: Name to return as the host name.
required: false
type: str
author:
- AWX Live Tests
'''
EXAMPLES = r'''
- name: Test with defaults
demo.external.example:
- name: Test with custom host name
demo.external.example:
host_name: foo_host
'''
RETURN = r'''
direct_host_name:
description: The name of the host, this will be collected with the feature.
type: str
returned: always
sample: 'foo_host'
'''
from ansible.module_utils.basic import AnsibleModule
def run_module():
module_args = dict(
host_name=dict(type='str', required=False, default='foo_host_default'),
)
result = dict(
changed=False,
other_data='sample_string',
)
module = AnsibleModule(argument_spec=module_args, supports_check_mode=True)
if module.check_mode:
module.exit_json(**result)
result['direct_host_name'] = module.params['host_name']
result['nested_host_name'] = {'host_name': module.params['host_name']}
result['name'] = 'vm-foo'
# non-cononical facts
result['device_type'] = 'Fake Host'
module.exit_json(**result)
def main():
run_module()
if __name__ == '__main__':
main()

View File

@@ -1,19 +0,0 @@
---
authors:
- AWX Project Contributors <awx-project@googlegroups.com>
dependencies: {}
description: External query testing collection v3.0.0. No embedded query file. Not for use in production.
documentation: https://github.com/ansible/awx
homepage: https://github.com/ansible/awx
issues: https://github.com/ansible/awx
license:
- GPL-3.0-or-later
name: external
namespace: demo
readme: README.md
repository: https://github.com/ansible/awx
tags:
- demo
- testing
- external_query
version: 3.0.0

View File

@@ -1,78 +0,0 @@
#!/usr/bin/python
# Same licensing as AWX
from __future__ import absolute_import, division, print_function
__metaclass__ = type
DOCUMENTATION = r'''
---
module: example
short_description: Module for specific live tests
version_added: "2.0.0"
description: This module is part of a test collection in local source. Used for external query testing.
options:
host_name:
description: Name to return as the host name.
required: false
type: str
author:
- AWX Live Tests
'''
EXAMPLES = r'''
- name: Test with defaults
demo.external.example:
- name: Test with custom host name
demo.external.example:
host_name: foo_host
'''
RETURN = r'''
direct_host_name:
description: The name of the host, this will be collected with the feature.
type: str
returned: always
sample: 'foo_host'
'''
from ansible.module_utils.basic import AnsibleModule
def run_module():
module_args = dict(
host_name=dict(type='str', required=False, default='foo_host_default'),
)
result = dict(
changed=False,
other_data='sample_string',
)
module = AnsibleModule(argument_spec=module_args, supports_check_mode=True)
if module.check_mode:
module.exit_json(**result)
result['direct_host_name'] = module.params['host_name']
result['nested_host_name'] = {'host_name': module.params['host_name']}
result['name'] = 'vm-foo'
# non-cononical facts
result['device_type'] = 'Fake Host'
module.exit_json(**result)
def main():
run_module()
if __name__ == '__main__':
main()

View File

@@ -1,21 +0,0 @@
---
# Generated by Claude Opus 4.6 (claude-opus-4-6).
- hosts: all
vars:
extra_value: ""
gather_facts: false
connection: local
tasks:
- name: set a custom fact
set_fact:
foo: "bar{{ extra_value }}"
bar:
a:
b:
- "c"
- "d"
cacheable: true
- name: sleep to create overlap window for concurrent job testing
wait_for:
timeout: 2

View File

@@ -1,5 +0,0 @@
---
collections:
- name: 'file:///tmp/live_tests/host_query_external_v1_0_0'
type: git
version: devel

View File

@@ -1,8 +0,0 @@
---
- hosts: all
gather_facts: false
connection: local
tasks:
- demo.external.example:
register: result
- debug: var=result

View File

@@ -1,5 +0,0 @@
---
collections:
- name: 'file:///tmp/live_tests/host_query_external_v1_5_0'
type: git
version: devel

View File

@@ -1,8 +0,0 @@
---
- hosts: all
gather_facts: false
connection: local
tasks:
- demo.external.example:
register: result
- debug: var=result

View File

@@ -1,5 +0,0 @@
---
collections:
- name: 'file:///tmp/live_tests/host_query_external_v3_0_0'
type: git
version: devel

View File

@@ -1,8 +0,0 @@
---
- hosts: all
gather_facts: false
connection: local
tasks:
- demo.external.example:
register: result
- debug: var=result

View File

@@ -1,11 +1,8 @@
import pytest import pytest
from django.test import RequestFactory
from prometheus_client.parser import text_string_to_metric_families from prometheus_client.parser import text_string_to_metric_families
from rest_framework.request import Request
from awx.main import models from awx.main import models
from awx.main.analytics.metrics import metrics from awx.main.analytics.metrics import metrics
from awx.main.analytics.dispatcherd_metrics import get_dispatcherd_metrics
from awx.api.versioning import reverse from awx.api.versioning import reverse
EXPECTED_VALUES = { EXPECTED_VALUES = {
@@ -80,55 +77,3 @@ def test_metrics_http_methods(get, post, patch, put, options, admin):
assert patch(get_metrics_view_db_only(), user=admin).status_code == 405 assert patch(get_metrics_view_db_only(), user=admin).status_code == 405
assert post(get_metrics_view_db_only(), user=admin).status_code == 405 assert post(get_metrics_view_db_only(), user=admin).status_code == 405
assert options(get_metrics_view_db_only(), user=admin).status_code == 200 assert options(get_metrics_view_db_only(), user=admin).status_code == 200
class DummyMetricsResponse:
def __init__(self, payload):
self._payload = payload
def read(self):
return self._payload
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def test_dispatcherd_metrics_node_filter_match(mocker, settings):
settings.CLUSTER_HOST_ID = "awx-1"
payload = b'# HELP test_metric A test metric\n# TYPE test_metric gauge\ntest_metric 1\n'
def fake_urlopen(url, timeout=1.0):
return DummyMetricsResponse(payload)
mocker.patch('urllib.request.urlopen', fake_urlopen)
request = Request(RequestFactory().get('/api/v2/metrics/', {'node': 'awx-1'}))
assert get_dispatcherd_metrics(request) == payload.decode('utf-8')
def test_dispatcherd_metrics_node_filter_excludes_local(mocker, settings):
settings.CLUSTER_HOST_ID = "awx-1"
def fake_urlopen(*args, **kwargs):
raise AssertionError("urlopen should not be called when node filter excludes local node")
mocker.patch('urllib.request.urlopen', fake_urlopen)
request = Request(RequestFactory().get('/api/v2/metrics/', {'node': 'awx-2'}))
assert get_dispatcherd_metrics(request) == ''
def test_dispatcherd_metrics_metric_filter_excludes_unrelated(mocker):
def fake_urlopen(*args, **kwargs):
raise AssertionError("urlopen should not be called when metric filter excludes dispatcherd metrics")
mocker.patch('urllib.request.urlopen', fake_urlopen)
request = Request(RequestFactory().get('/api/v2/metrics/', {'metric': 'awx_system_info'}))
assert get_dispatcherd_metrics(request) == ''

View File

@@ -1,7 +1,5 @@
import pytest import pytest
from ansible_base.lib.testing.util import feature_flag_enabled, feature_flag_disabled
from awx.main.models import CredentialInputSource from awx.main.models import CredentialInputSource
from awx.api.versioning import reverse from awx.api.versioning import reverse
@@ -318,60 +316,3 @@ def test_create_credential_input_source_with_already_used_input_returns_400(post
] ]
all_responses = [post(list_url, params, admin) for params in all_params] all_responses = [post(list_url, params, admin) for params in all_params]
assert all_responses.pop().status_code == 400 assert all_responses.pop().status_code == 400
@pytest.mark.django_db
def test_credential_input_source_passes_workload_identity_token_when_flag_enabled(vault_credential, external_credential, mocker):
"""Test that workload_identity_token is passed to backend when flag is enabled."""
with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
# Add workload_identity_token as an internal field on the external credential type
# so get_input_value resolves it from the per-input-source context
external_credential.credential_type.inputs['fields'].append(
{'id': 'workload_identity_token', 'label': 'Workload Identity Token', 'type': 'string', 'internal': True}
)
# Create an input source
input_source = CredentialInputSource.objects.create(
target_credential=vault_credential,
source_credential=external_credential,
input_field_name='vault_password',
metadata={'key': 'test_key'},
)
# Mock the credential plugin backend
mock_backend = mocker.patch.object(external_credential.credential_type.plugin, 'backend', autospec=True, return_value='test_value')
# Call with context keyed by input source PK
test_context = {input_source.pk: {'workload_identity_token': 'jwt_token_here'}}
result = input_source.get_input_value(context=test_context)
# Verify backend was called with workload_identity_token
assert result == 'test_value'
call_kwargs = mock_backend.call_args[1]
assert call_kwargs['workload_identity_token'] == 'jwt_token_here'
assert call_kwargs['key'] == 'test_key'
@pytest.mark.django_db
def test_credential_input_source_skips_workload_identity_token_when_flag_disabled(vault_credential, external_credential, mocker):
"""Test that workload_identity_token is NOT passed when flag is disabled."""
with feature_flag_disabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
# Create an input source
input_source = CredentialInputSource.objects.create(
target_credential=vault_credential,
source_credential=external_credential,
input_field_name='vault_password',
metadata={'key': 'test_key'},
)
# Mock the credential plugin backend
mock_backend = mocker.patch.object(external_credential.credential_type.plugin, 'backend', autospec=True, return_value='test_value')
# Call with context containing workload_identity_token but NO internal field defined,
# simulating a flag-disabled scenario where tokens are not generated upstream
test_context = {input_source.pk: {'workload_identity_token': 'jwt_token_here'}}
result = input_source.get_input_value(context=test_context)
# Verify backend was called WITHOUT workload_identity_token since the credential type
# does not define it as an internal field (flag-disabled path doesn't register it)
assert result == 'test_value'
call_kwargs = mock_backend.call_args[1]
assert 'workload_identity_token' not in call_kwargs
assert call_kwargs['key'] == 'test_key'

View File

@@ -463,26 +463,6 @@ class TestInventorySourceCredential:
assert 'Cloud-based inventory sources (such as ec2)' in r.data['credential'][0] assert 'Cloud-based inventory sources (such as ec2)' in r.data['credential'][0]
assert 'require credentials for the matching cloud service' in r.data['credential'][0] assert 'require credentials for the matching cloud service' in r.data['credential'][0]
def test_credential_dict_value_returns_400(self, inventory, admin_user, put):
"""Passing a dict for the credential field should return 400, not 500.
Reproduces a bug where int() raises TypeError on non-scalar types
(dict, list) which was uncaught, resulting in a 500 Internal Server Error.
"""
inv_src = InventorySource.objects.create(name='test-src', inventory=inventory, source='ec2')
r = put(
url=reverse('api:inventory_source_detail', kwargs={'pk': inv_src.pk}),
data={
'name': 'test-src',
'inventory': inventory.pk,
'source': 'ec2',
'credential': {'username': 'admin', 'password': 'secret'},
},
user=admin_user,
expect=400,
)
assert r.status_code == 400
def test_vault_credential_not_allowed(self, project, inventory, vault_credential, admin_user, post): def test_vault_credential_not_allowed(self, project, inventory, vault_credential, admin_user, post):
"""Vault credentials cannot be associated via the deprecated field""" """Vault credentials cannot be associated via the deprecated field"""
# TODO: when feature is added, add tests to use the related credentials # TODO: when feature is added, add tests to use the related credentials

View File

@@ -1,163 +0,0 @@
"""
Tests for OIDC workload identity credential type feature flag.
The FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED flag is an install-time flag that
controls whether OIDC credential types are loaded into the registry at startup.
When disabled, OIDC credential types are not loaded and do not exist in the database.
"""
import pytest
from unittest import mock
from django.test import override_settings
from awx.main.constants import OIDC_CREDENTIAL_TYPE_NAMESPACES
from awx.main.models.credential import CredentialType, ManagedCredentialType, load_credentials
from awx.api.versioning import reverse
@pytest.fixture
def reload_credentials_with_flag(django_db_setup, django_db_blocker):
"""
Fixture that reloads credentials with a specific flag state.
This simulates what happens at application startup.
"""
# Save original registry state
original_registry = ManagedCredentialType.registry.copy()
def _reload(flag_enabled):
with django_db_blocker.unblock():
# Clear the entire registry before reloading
ManagedCredentialType.registry.clear()
# Reload credentials with the specified flag state
with override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=flag_enabled):
with mock.patch('awx.main.models.credential.detect_server_product_name', return_value='NOT_AWX'):
load_credentials()
# Sync to database
CredentialType.setup_tower_managed_defaults(lock=False)
# In tests, the session fixture pre-loads all credential types into the DB.
# Remove OIDC types when testing the disabled state so the API test is accurate.
if not flag_enabled:
CredentialType.objects.filter(namespace__in=OIDC_CREDENTIAL_TYPE_NAMESPACES).delete()
yield _reload
# Restore original registry state after tests
ManagedCredentialType.registry.clear()
ManagedCredentialType.registry.update(original_registry)
@pytest.fixture
def isolated_registry():
"""Save and restore the ManagedCredentialType registry, with full isolation via mocked entry_points."""
original_registry = ManagedCredentialType.registry.copy()
ManagedCredentialType.registry.clear()
yield
ManagedCredentialType.registry.clear()
ManagedCredentialType.registry.update(original_registry)
def _make_mock_entry_point(name):
"""Create a mock entry point that mimics a credential plugin."""
ep = mock.MagicMock()
ep.name = name
ep.value = f'test_plugin:{name}'
plugin = mock.MagicMock(spec=[])
ep.load.return_value = plugin
return ep
def _mock_entry_points_factory(managed_names, supported_names):
"""Return a side_effect function for mocking entry_points() with controlled plugins."""
managed = [_make_mock_entry_point(n) for n in managed_names]
supported = [_make_mock_entry_point(n) for n in supported_names]
def _entry_points(group):
if group == 'awx_plugins.managed_credentials':
return managed
elif group == 'awx_plugins.managed_credentials.supported':
return supported
return []
return _entry_points
# --- Unit tests for load_credentials() registry behavior ---
def test_oidc_types_in_registry_when_flag_enabled(isolated_registry):
"""Test that OIDC credential types are added to the registry when flag is enabled."""
mock_eps = _mock_entry_points_factory(
managed_names=['ssh', 'vault'],
supported_names=['hashivault-kv-oidc', 'hashivault-ssh-oidc'],
)
with override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=True):
with mock.patch('awx.main.models.credential.detect_server_product_name', return_value='NOT_AWX'):
with mock.patch('awx.main.models.credential.entry_points', side_effect=mock_eps):
load_credentials()
for ns in OIDC_CREDENTIAL_TYPE_NAMESPACES:
assert ns in ManagedCredentialType.registry, f"{ns} should be in registry when flag is enabled"
assert 'ssh' in ManagedCredentialType.registry
assert 'vault' in ManagedCredentialType.registry
def test_oidc_types_not_in_registry_when_flag_disabled(isolated_registry):
"""Test that OIDC credential types are excluded from the registry when flag is disabled."""
mock_eps = _mock_entry_points_factory(
managed_names=['ssh', 'vault'],
supported_names=['hashivault-kv-oidc', 'hashivault-ssh-oidc'],
)
with override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=False):
with mock.patch('awx.main.models.credential.detect_server_product_name', return_value='NOT_AWX'):
with mock.patch('awx.main.models.credential.entry_points', side_effect=mock_eps):
load_credentials()
for ns in OIDC_CREDENTIAL_TYPE_NAMESPACES:
assert ns not in ManagedCredentialType.registry, f"{ns} should not be in registry when flag is disabled"
# Non-OIDC types should still be loaded
assert 'ssh' in ManagedCredentialType.registry
assert 'vault' in ManagedCredentialType.registry
def test_oidc_namespaces_constant():
"""Test that OIDC_CREDENTIAL_TYPE_NAMESPACES contains the expected namespaces."""
assert 'hashivault-kv-oidc' in OIDC_CREDENTIAL_TYPE_NAMESPACES
assert 'hashivault-ssh-oidc' in OIDC_CREDENTIAL_TYPE_NAMESPACES
assert len(OIDC_CREDENTIAL_TYPE_NAMESPACES) == 2
# --- Functional API tests ---
@pytest.mark.django_db
def test_oidc_types_loaded_when_flag_enabled(get, admin, reload_credentials_with_flag):
"""Test that OIDC credential types are visible in the API when flag is enabled."""
reload_credentials_with_flag(flag_enabled=True)
response = get(reverse('api:credential_type_list'), admin)
assert response.status_code == 200
namespaces = [ct['namespace'] for ct in response.data['results']]
assert 'hashivault-kv-oidc' in namespaces
assert 'hashivault-ssh-oidc' in namespaces
@pytest.mark.django_db
def test_oidc_types_not_loaded_when_flag_disabled(get, admin, reload_credentials_with_flag):
"""Test that OIDC credential types are not visible in the API when flag is disabled."""
reload_credentials_with_flag(flag_enabled=False)
response = get(reverse('api:credential_type_list'), admin)
assert response.status_code == 200
namespaces = [ct['namespace'] for ct in response.data['results']]
assert 'hashivault-kv-oidc' not in namespaces
assert 'hashivault-ssh-oidc' not in namespaces
# Verify they're also not in the database
assert not CredentialType.objects.filter(namespace='hashivault-kv-oidc').exists()
assert not CredentialType.objects.filter(namespace='hashivault-ssh-oidc').exists()

View File

@@ -1,3 +1,4 @@
from datetime import date
from unittest import mock from unittest import mock
import pytest import pytest
@@ -252,7 +253,7 @@ def test_user_verify_attribute_created(admin, get):
resp = get(reverse('api:user_detail', kwargs={'pk': admin.pk}), admin) resp = get(reverse('api:user_detail', kwargs={'pk': admin.pk}), admin)
assert resp.data['created'] == admin.date_joined assert resp.data['created'] == admin.date_joined
past = "2020-01-01T00:00:00Z" past = date(2020, 1, 1).isoformat()
for op, count in (('gt', 1), ('lt', 0)): for op, count in (('gt', 1), ('lt', 0)):
resp = get(reverse('api:user_list') + f'?created__{op}={past}', admin) resp = get(reverse('api:user_list') + f'?created__{op}={past}', admin)
assert resp.data['count'] == count assert resp.data['count'] == count

View File

@@ -48,7 +48,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
worker = CallbackBrokerWorker() worker = CallbackBrokerWorker()
events = [InventoryUpdateEvent(uuid=str(uuid4()), **self.event_create_kwargs())] events = [InventoryUpdateEvent(uuid=str(uuid4()), **self.event_create_kwargs())]
worker.buff = {InventoryUpdateEvent: events} worker.buff = {InventoryUpdateEvent: events}
worker.flush(force=True) worker.flush()
assert worker.buff.get(InventoryUpdateEvent, []) == [] assert worker.buff.get(InventoryUpdateEvent, []) == []
assert InventoryUpdateEvent.objects.filter(uuid=events[0].uuid).count() == 1 assert InventoryUpdateEvent.objects.filter(uuid=events[0].uuid).count() == 1
@@ -61,7 +61,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
InventoryUpdateEvent(uuid=str(uuid4()), stdout='good2', **kwargs), InventoryUpdateEvent(uuid=str(uuid4()), stdout='good2', **kwargs),
] ]
worker.buff = {InventoryUpdateEvent: events.copy()} worker.buff = {InventoryUpdateEvent: events.copy()}
worker.flush(force=True) worker.flush()
assert InventoryUpdateEvent.objects.filter(uuid=events[0].uuid).count() == 1 assert InventoryUpdateEvent.objects.filter(uuid=events[0].uuid).count() == 1
assert InventoryUpdateEvent.objects.filter(uuid=events[1].uuid).count() == 0 assert InventoryUpdateEvent.objects.filter(uuid=events[1].uuid).count() == 0
assert InventoryUpdateEvent.objects.filter(uuid=events[2].uuid).count() == 1 assert InventoryUpdateEvent.objects.filter(uuid=events[2].uuid).count() == 1
@@ -71,7 +71,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
worker = CallbackBrokerWorker() worker = CallbackBrokerWorker()
events = [InventoryUpdateEvent(uuid=str(uuid4()), **self.event_create_kwargs())] events = [InventoryUpdateEvent(uuid=str(uuid4()), **self.event_create_kwargs())]
worker.buff = {InventoryUpdateEvent: events.copy()} worker.buff = {InventoryUpdateEvent: events.copy()}
worker.flush(force=True) worker.flush()
# put current saved event in buffer (error case) # put current saved event in buffer (error case)
worker.buff = {InventoryUpdateEvent: [InventoryUpdateEvent.objects.get(uuid=events[0].uuid)]} worker.buff = {InventoryUpdateEvent: [InventoryUpdateEvent.objects.get(uuid=events[0].uuid)]}
@@ -113,7 +113,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
with mock.patch.object(InventoryUpdateEvent.objects, 'bulk_create', side_effect=ValueError): with mock.patch.object(InventoryUpdateEvent.objects, 'bulk_create', side_effect=ValueError):
with mock.patch.object(events[0], 'save', side_effect=ValueError): with mock.patch.object(events[0], 'save', side_effect=ValueError):
worker.flush(force=True) worker.flush()
assert "\x00" not in events[0].stdout assert "\x00" not in events[0].stdout

View File

@@ -1,17 +0,0 @@
import pytest
from awx.main.dispatch.config import get_dispatcherd_config
from awx.main.management.commands.dispatcherd import _hash_config
@pytest.mark.django_db
def test_dispatcherd_config_hash_is_stable(settings, monkeypatch):
monkeypatch.setenv('AWX_COMPONENT', 'dispatcher')
settings.CLUSTER_HOST_ID = 'test-node'
settings.JOB_EVENT_WORKERS = 1
settings.DISPATCHER_SCHEDULE = {}
config_one = get_dispatcherd_config(for_service=True)
config_two = get_dispatcherd_config(for_service=True)
assert _hash_config(config_one) == _hash_config(config_two)

View File

@@ -10,23 +10,9 @@ from django.test.utils import override_settings
@pytest.mark.django_db @pytest.mark.django_db
def test_multiple_hybrid_instances(): def test_multiple_instances():
for i in range(3):
Instance.objects.create(hostname=f'foo{i}', node_type='hybrid')
assert is_ha_environment()
@pytest.mark.django_db
def test_double_control_instances():
for i in range(2): for i in range(2):
Instance.objects.create(hostname=f'foo{i}', node_type='control') Instance.objects.create(hostname=f'foo{i}', node_type='hybrid')
assert is_ha_environment()
@pytest.mark.django_db
def test_mix_hybrid_control_instances():
Instance.objects.create(hostname='control_node', node_type='control')
Instance.objects.create(hostname='hybrid_node', node_type='hybrid')
assert is_ha_environment() assert is_ha_environment()

View File

@@ -1,6 +1,6 @@
import pytest import pytest
from awx.main.models import JobTemplate, Job, JobHostSummary, WorkflowJob, Inventory, Host, Project, Organization from awx.main.models import JobTemplate, Job, JobHostSummary, WorkflowJob, Inventory, Project, Organization
@pytest.mark.django_db @pytest.mark.django_db
@@ -87,47 +87,3 @@ class TestSlicingModels:
unified_job = job_template.create_unified_job(job_slice_count=2) unified_job = job_template.create_unified_job(job_slice_count=2)
assert isinstance(unified_job, Job) assert isinstance(unified_job, Job)
@pytest.mark.django_db
class TestGetSourceHostsForConstructedInventory:
    """Tests for Job.get_source_hosts_for_constructed_inventory"""

    def test_returns_source_hosts_via_instance_id(self):
        """Constructed hosts with instance_id pointing to source hosts are resolved correctly."""
        org = Organization.objects.create(name='test-org')
        input_inventory = Inventory.objects.create(organization=org, name='input-inv')
        first_source = input_inventory.hosts.create(name='host1')
        second_source = input_inventory.hosts.create(name='host2')
        constructed = Inventory.objects.create(organization=org, name='constructed-inv', kind='constructed')
        constructed.input_inventories.add(input_inventory)
        # Mirror each source host into the constructed inventory, linked via instance_id.
        for source in (first_source, second_source):
            Host.objects.create(inventory=constructed, name=source.name, instance_id=str(source.id))
        job = Job.objects.create(name='test-job', inventory=constructed)
        resolved_ids = set(job.get_source_hosts_for_constructed_inventory().values_list('id', flat=True))
        assert resolved_ids == {first_source.id, second_source.id}

    def test_no_inventory_returns_empty(self):
        """A job with no inventory returns an empty queryset."""
        job = Job.objects.create(name='test-job')
        assert job.get_source_hosts_for_constructed_inventory().count() == 0

    def test_ignores_hosts_without_instance_id(self):
        """Hosts with empty instance_id are excluded from the result."""
        org = Organization.objects.create(name='test-org')
        input_inventory = Inventory.objects.create(organization=org, name='input-inv')
        source_host = input_inventory.hosts.create(name='host1')
        constructed = Inventory.objects.create(organization=org, name='constructed-inv', kind='constructed')
        constructed.input_inventories.add(input_inventory)
        Host.objects.create(inventory=constructed, name='host1', instance_id=str(source_host.id))
        # This host has no back-reference and must not appear in the result.
        Host.objects.create(inventory=constructed, name='host-no-ref', instance_id='')
        job = Job.objects.create(name='test-job', inventory=constructed)
        resolved = list(job.get_source_hosts_for_constructed_inventory().values_list('id', flat=True))
        assert resolved == [source_host.id]

View File

@@ -1,6 +1,5 @@
import itertools import itertools
import pytest import pytest
from uuid import uuid4
# CRUM # CRUM
from crum import impersonate from crum import impersonate
@@ -34,64 +33,6 @@ def test_soft_unique_together(post, project, admin_user):
assert 'combination already exists' in str(r.data) assert 'combination already exists' in str(r.data)
@pytest.mark.django_db
class TestJobCancel:
    """
    Coverage for UnifiedJob.cancel, focused on interaction with dispatcherd objects.
    Using mocks for the dispatcherd objects, because tests by default use a no-op broker.
    """

    def test_cancel_sets_flag_and_clears_start_args(self, mocker):
        job = Job.objects.create(status='running', name='foo-job', celery_task_id=str(uuid4()), controller_node='foo', start_args='{"secret": "value"}')
        job.websocket_emit_status = mocker.MagicMock()
        # Sanity: a running job with a task id is cancelable and not yet flagged.
        assert job.can_cancel is True
        assert job.cancel_flag is False
        job.cancel()
        job.refresh_from_db()
        assert job.cancel_flag is True
        # start_args may hold secrets, so cancel wipes it.
        assert job.start_args == ''

    def test_cancel_sets_job_explanation(self, mocker):
        job = Job.objects.create(status='running', name='foo-job', celery_task_id=str(uuid4()), controller_node='foo')
        job.websocket_emit_status = mocker.MagicMock()
        explanation = 'giggity giggity'
        job.cancel(job_explanation=explanation)
        job.refresh_from_db()
        assert job.job_explanation == explanation

    def test_cancel_sends_control_message(self, mocker):
        task_id = str(uuid4())
        job = Job.objects.create(status='running', name='foo-job', celery_task_id=task_id, controller_node='foo')
        job.websocket_emit_status = mocker.MagicMock()
        fake_control = mocker.MagicMock()
        mock_get_control = mocker.patch('awx.main.models.unified_jobs.get_control_from_settings', return_value=fake_control)
        job.cancel()
        # The control message targets the controller node's channel with the task uuid.
        mock_get_control.assert_called_once_with(default_publish_channel='foo')
        fake_control.control.assert_called_once_with('cancel', data={'uuid': task_id})

    def test_cancel_refreshes_task_id_before_sending_control(self, mocker):
        # Job starts pending with no task id; the DB is updated behind its back.
        job = Job.objects.create(status='pending', name='foo-job', celery_task_id='', controller_node='bar')
        job.websocket_emit_status = mocker.MagicMock()
        task_id = str(uuid4())
        Job.objects.filter(pk=job.pk).update(status='running', celery_task_id=task_id)
        fake_control = mocker.MagicMock()
        mock_get_control = mocker.patch('awx.main.models.unified_jobs.get_control_from_settings', return_value=fake_control)
        refresh_spy = mocker.spy(job, 'refresh_from_db')
        job.cancel()
        # cancel() must re-read these fields so it messages the right node/uuid.
        refresh_spy.assert_called_once_with(fields=['celery_task_id', 'controller_node'])
        mock_get_control.assert_called_once_with(default_publish_channel='bar')
        fake_control.control.assert_called_once_with('cancel', data={'uuid': task_id})
@pytest.mark.django_db @pytest.mark.django_db
class TestCreateUnifiedJob: class TestCreateUnifiedJob:
""" """

View File

@@ -1,274 +0,0 @@
# Generated by Claude Opus 4.6 (claude-opus-4-6)
#
# Test file for cancel + dependency chain behavior and workflow cancel propagation.
#
# These tests verify:
#
# 1. TaskManager.process_job_dep_failures() correctly distinguishes canceled vs
# failed dependencies in the job_explanation message.
#
# 2. TaskManager.process_pending_tasks() transitions pending jobs with
# cancel_flag=True directly to canceled status.
#
# 3. WorkflowManager + TaskManager together cancel all spawned jobs in a
# workflow and finalize the workflow as canceled.
import pytest
from unittest import mock
from awx.main.scheduler import TaskManager, DependencyManager, WorkflowManager
from awx.main.models import JobTemplate, ProjectUpdate, WorkflowApproval, WorkflowJobTemplate
from awx.main.models.workflow import WorkflowApprovalTemplate
from . import create_job
@pytest.fixture
def scm_on_launch_objects(job_template_factory):
    """Create a job template with a project configured for scm_update_on_launch."""
    objects = job_template_factory('jt', organization='org1', project='proj', inventory='inv', credential='cred')
    project = objects.project
    project.scm_update_on_launch = True
    # Zero cache timeout forces a fresh project update for every launch.
    project.scm_update_cache_timeout = 0
    project.save(skip_update=True)
    return objects
def _create_job_with_dependency(objects):
    """Create a pending job and run DependencyManager to produce its project update dependency.

    Returns (job, project_update).
    """
    job = create_job(objects.job_template, dependencies_processed=False)
    # Patch out the real SCM update so scheduling only records the dependency.
    with mock.patch('awx.main.models.unified_jobs.UnifiedJobTemplate.update'):
        DependencyManager().schedule()
    assert job.dependent_jobs.count() == 1
    project_update = job.dependent_jobs.first()
    assert isinstance(project_update.get_real_instance(), ProjectUpdate)
    return job, project_update
@pytest.mark.django_db
class TestCanceledDependencyFailsBlockedJob:
    """When a dependency project update is canceled or failed, the task manager
    should fail the blocked job via process_job_dep_failures."""

    def test_canceled_dependency_fails_blocked_job(self, controlplane_instance_group, scm_on_launch_objects):
        """A canceled dependency causes the blocked job to be failed with
        a 'Previous Task Canceled' explanation."""
        job, update = _create_job_with_dependency(scm_on_launch_objects)
        ProjectUpdate.objects.filter(pk=update.pk).update(status='canceled', cancel_flag=True)
        with mock.patch("awx.main.scheduler.TaskManager.start_task"):
            TaskManager().schedule()
        job.refresh_from_db()
        assert job.status == 'failed'
        assert 'Previous Task Canceled' in job.job_explanation

    def test_failed_dependency_fails_blocked_job(self, controlplane_instance_group, scm_on_launch_objects):
        """A failed dependency causes the blocked job to be failed with
        a 'Previous Task Failed' explanation."""
        job, update = _create_job_with_dependency(scm_on_launch_objects)
        ProjectUpdate.objects.filter(pk=update.pk).update(status='failed')
        with mock.patch("awx.main.scheduler.TaskManager.start_task"):
            TaskManager().schedule()
        job.refresh_from_db()
        assert job.status == 'failed'
        assert 'Previous Task Failed' in job.job_explanation
@pytest.mark.django_db
class TestTaskManagerCancelsPendingJobsWithCancelFlag:
    """When the task manager encounters pending jobs that have cancel_flag set,
    it should transition them directly to canceled status."""

    def test_pending_job_with_cancel_flag_is_canceled(self, controlplane_instance_group, job_template_factory):
        """A pending job with cancel_flag=True is transitioned to canceled
        by the task manager without being started."""
        objs = job_template_factory('jt', organization='org1', project='proj', inventory='inv', credential='cred')
        job = create_job(objs.job_template)
        job.cancel_flag = True
        job.save(update_fields=['cancel_flag'])
        with mock.patch("awx.main.scheduler.TaskManager.start_task") as mock_start:
            TaskManager().schedule()
        job.refresh_from_db()
        assert job.status == 'canceled'
        assert 'canceled before it started' in job.job_explanation
        assert not mock_start.called

    def test_pending_job_without_cancel_flag_is_not_canceled(self, controlplane_instance_group, job_template_factory):
        """A normal pending job without cancel_flag should not be canceled
        by the task manager (sanity check)."""
        objs = job_template_factory('jt', organization='org1', project='proj', inventory='inv', credential='cred')
        job = create_job(objs.job_template)
        with mock.patch("awx.main.scheduler.TaskManager.start_task"):
            TaskManager().schedule()
        job.refresh_from_db()
        assert job.status != 'canceled'

    def test_multiple_pending_jobs_with_cancel_flag_bulk_canceled(self, controlplane_instance_group, job_template_factory):
        """Multiple pending jobs with cancel_flag=True are all transitioned
        to canceled in a single task manager cycle."""
        objs = job_template_factory('jt', organization='org1', project='proj', inventory='inv', credential='cred')
        template = objs.job_template
        # Allow all three jobs to be eligible at once.
        template.allow_simultaneous = True
        template.save()
        pending = []
        for _ in range(3):
            job = create_job(template)
            job.cancel_flag = True
            job.save(update_fields=['cancel_flag'])
            pending.append(job)
        with mock.patch("awx.main.scheduler.TaskManager.start_task") as mock_start:
            TaskManager().schedule()
        for job in pending:
            job.refresh_from_db()
            assert job.status == 'canceled', f"Job {job.id} should be canceled but is {job.status}"
            assert 'canceled before it started' in job.job_explanation
        assert not mock_start.called
@pytest.mark.django_db
class TestWorkflowCancelFinalizesWorkflow:
    """When a workflow job is canceled, the WorkflowManager cancels spawned child
    jobs (setting cancel_flag), the TaskManager transitions those pending jobs to
    canceled, and a final WorkflowManager pass finalizes the workflow as canceled."""

    def test_cancel_workflow_with_parallel_nodes(self, inventory, project, controlplane_instance_group):
        """Create a workflow with parallel nodes, cancel it after one job is
        running, and verify all jobs and the workflow reach canceled status."""
        jt = JobTemplate.objects.create(allow_simultaneous=False, inventory=inventory, project=project, playbook='helloworld.yml')
        wfjt = WorkflowJobTemplate.objects.create(name='test-cancel-wf')
        for _ in range(4):
            wfjt.workflow_nodes.create(unified_job_template=jt)
        wf_job = wfjt.create_unified_job()
        wf_job.signal_start()
        # TaskManager transitions workflow job to running via start_task
        TaskManager().schedule()
        wf_job.refresh_from_db()
        assert wf_job.status == 'running'
        # WorkflowManager spawns jobs for all 4 nodes
        WorkflowManager().schedule()
        assert jt.jobs.count() == 4
        # Simulate one job running (blocking the others via allow_simultaneous=False)
        running_job = jt.jobs.order_by('created').first()
        running_job.status = 'running'
        running_job.celery_task_id = 'fake-task-id'
        running_job.controller_node = 'test-node'
        running_job.save(update_fields=['status', 'celery_task_id', 'controller_node'])
        # Cancel the workflow
        wf_job.cancel_flag = True
        wf_job.save(update_fields=['cancel_flag'])
        # WorkflowManager sees cancel_flag, calls cancel_node_jobs() which sets
        # cancel_flag on all child jobs
        with mock.patch('awx.main.models.unified_jobs.UnifiedJob.cancel_dispatcher_process'):
            WorkflowManager().schedule()
        # The running job won't actually stop in tests (no dispatcher), simulate it
        running_job.status = 'canceled'
        running_job.save(update_fields=['status'])
        # TaskManager processes remaining pending jobs with cancel_flag set
        with mock.patch("awx.main.scheduler.TaskManager.start_task") as mock_start:
            DependencyManager().schedule()
            TaskManager().schedule()
        for child in jt.jobs.all():
            child.refresh_from_db()
            assert child.status == 'canceled', f"Job {child.id} should be canceled but is {child.status}"
        assert not mock_start.called
        # Final WorkflowManager pass finalizes the workflow
        WorkflowManager().schedule()
        wf_job.refresh_from_db()
        assert wf_job.status == 'canceled'

    def test_cancel_workflow_with_approval_node(self, controlplane_instance_group):
        """Create a workflow with a pending approval node and a downstream job
        node. Cancel the workflow and verify the approval is directly canceled
        by the WorkflowManager (since approvals are excluded from TaskManager),
        the downstream node is marked do_not_run, and the workflow finalizes."""
        approval_template = WorkflowApprovalTemplate.objects.create(name='test-approval', timeout=0)
        wfjt = WorkflowJobTemplate.objects.create(name='test-cancel-approval-wf')
        approval_node = wfjt.workflow_nodes.create(unified_job_template=approval_template)
        # Add a downstream node (just another approval to keep it simple)
        downstream_template = WorkflowApprovalTemplate.objects.create(name='test-downstream', timeout=0)
        downstream_node = wfjt.workflow_nodes.create(unified_job_template=downstream_template)
        approval_node.success_nodes.add(downstream_node)
        wf_job = wfjt.create_unified_job()
        wf_job.signal_start()
        # TaskManager transitions workflow to running
        TaskManager().schedule()
        wf_job.refresh_from_db()
        assert wf_job.status == 'running'
        # WorkflowManager spawns the approval (root node only, downstream waits)
        WorkflowManager().schedule()
        assert WorkflowApproval.objects.filter(unified_job_node__workflow_job=wf_job).count() == 1
        approval_job = WorkflowApproval.objects.get(unified_job_node__workflow_job=wf_job)
        assert approval_job.status == 'pending'
        # Cancel the workflow
        wf_job.cancel_flag = True
        wf_job.save(update_fields=['cancel_flag'])
        # WorkflowManager should cancel the approval directly and mark
        # the downstream node as do_not_run
        WorkflowManager().schedule()
        approval_job.refresh_from_db()
        assert approval_job.status == 'canceled', f"Approval should be canceled directly by WorkflowManager but is {approval_job.status}"
        # Downstream node should be marked do_not_run with no job spawned
        downstream_wf_node = wf_job.workflow_nodes.get(unified_job_template=downstream_template)
        assert downstream_wf_node.do_not_run is True
        assert downstream_wf_node.job is None
        # Workflow should finalize as canceled in the same pass
        wf_job.refresh_from_db()
        assert wf_job.status == 'canceled'

View File

@@ -1,223 +0,0 @@
"""Functional tests for start_fact_cache / finish_fact_cache.
These tests use real database objects (via pytest-django) and real files
on disk, but do not launch jobs or subprocesses. Fact files are written
by start_fact_cache and then manipulated to simulate ansible output
before calling finish_fact_cache.
Generated by Claude Opus 4.6 (claude-opus-4-6).
"""
import json
import os
import time
from datetime import timedelta
import pytest
from django.utils.timezone import now
from awx.main.models import Host, Inventory
from awx.main.tasks.facts import start_fact_cache, finish_fact_cache
@pytest.fixture
def artifacts_dir(tmp_path):
    """Return a fresh per-test artifacts directory as a string path."""
    artifacts = tmp_path / 'artifacts'
    artifacts.mkdir()
    return str(artifacts)
@pytest.mark.django_db
class TestFinishFactCacheScoping:
    """finish_fact_cache must only update hosts matched by the provided queryset."""

    def test_same_hostname_different_inventories(self, organization, artifacts_dir):
        """Two inventories share a hostname; only the targeted one should be updated.

        Generated by Claude Opus 4.6 (claude-opus-4-6).
        """
        first_inv = Inventory.objects.create(organization=organization, name='scope-inv1')
        second_inv = Inventory.objects.create(organization=organization, name='scope-inv2')
        targeted_host = first_inv.hosts.create(name='shared')
        untouched_host = second_inv.hosts.create(name='shared')
        # Seed both hosts with identical initial facts.
        for existing in (targeted_host, untouched_host):
            existing.ansible_facts = {'original': True}
            existing.ansible_facts_modified = now()
            existing.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
        # Reference files are written only for the first inventory's hosts.
        start_fact_cache(first_inv.hosts.all(), artifacts_dir=artifacts_dir, timeout=0, inventory_id=first_inv.id)
        # Simulate ansible writing new facts for 'shared'; bump mtime into the
        # future so the change is seen as newer than the reference file.
        fact_file = os.path.join(artifacts_dir, 'fact_cache', 'shared')
        future = time.time() + 60
        with open(fact_file, 'w') as f:
            json.dump({'updated': True}, f)
        os.utime(fact_file, (future, future))
        # Finish with only the first inventory's hosts in the queryset.
        finish_fact_cache(first_inv.hosts, artifacts_dir=artifacts_dir, inventory_id=first_inv.id)
        targeted_host.refresh_from_db()
        untouched_host.refresh_from_db()
        assert targeted_host.ansible_facts == {'updated': True}
        assert untouched_host.ansible_facts == {'original': True}, 'Host in a different inventory was modified despite not being in the queryset'
@pytest.mark.django_db
class TestFinishFactCacheConcurrentProtection:
    """finish_fact_cache must not clear facts that a concurrent job updated."""

    def test_no_clear_when_no_file_was_written(self, organization, artifacts_dir):
        """Host with no prior facts should not have facts cleared when file is missing.

        Generated by Claude Opus 4.6 (claude-opus-4-6).

        start_fact_cache records hosts_cached[host] = False for hosts with no
        prior facts (no file written). finish_fact_cache should skip the clear
        for these hosts because the missing file is expected, not a clear signal.
        """
        inv = Inventory.objects.create(organization=organization, name='concurrent-inv')
        target = inv.hosts.create(name='target')
        job_created = now() - timedelta(minutes=5)
        # start_fact_cache records host with False (no facts → no file written)
        start_fact_cache(inv.hosts.all(), artifacts_dir=artifacts_dir, timeout=0, inventory_id=inv.id)
        # Simulate a concurrent job updating this host's facts AFTER our job was created
        target.ansible_facts = {'from_concurrent_job': True}
        target.ansible_facts_modified = now()
        target.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
        # The fact file is missing because start_fact_cache never wrote one.
        # finish_fact_cache should skip this host entirely.
        finish_fact_cache(inv.hosts, artifacts_dir=artifacts_dir, inventory_id=inv.id, job_created=job_created)
        target.refresh_from_db()
        assert target.ansible_facts == {'from_concurrent_job': True}, 'Facts were cleared for a host that never had a fact file written'

    def test_skip_clear_when_facts_modified_after_job_created(self, organization, artifacts_dir):
        """If a file was written and then deleted, but facts were concurrently updated, skip clear.

        Generated by Claude Opus 4.6 (claude-opus-4-6).
        """
        inv = Inventory.objects.create(organization=organization, name='concurrent-written-inv')
        target = inv.hosts.create(name='target')
        stale_time = now() - timedelta(hours=1)
        target.ansible_facts = {'original': True}
        target.ansible_facts_modified = stale_time
        target.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
        job_created = now() - timedelta(minutes=5)
        # start_fact_cache writes a file (host has facts → True in map)
        start_fact_cache(inv.hosts.all(), artifacts_dir=artifacts_dir, timeout=0, inventory_id=inv.id)
        # Remove the fact file (ansible didn't target this host via --limit)
        os.remove(os.path.join(artifacts_dir, 'fact_cache', target.name))
        # Simulate a concurrent job updating this host's facts AFTER our job was created
        target.ansible_facts = {'from_concurrent_job': True}
        target.ansible_facts_modified = now()
        target.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
        finish_fact_cache(inv.hosts, artifacts_dir=artifacts_dir, inventory_id=inv.id, job_created=job_created)
        target.refresh_from_db()
        assert target.ansible_facts == {'from_concurrent_job': True}, 'Facts set by a concurrent job were cleared despite ansible_facts_modified > job_created'

    def test_clear_when_facts_predate_job(self, organization, artifacts_dir):
        """If facts predate the job, a missing file should still clear them.

        Generated by Claude Opus 4.6 (claude-opus-4-6).
        """
        inv = Inventory.objects.create(organization=organization, name='clear-inv')
        stale_host = inv.hosts.create(name='stale')
        stale_time = now() - timedelta(hours=1)
        stale_host.ansible_facts = {'stale': True}
        stale_host.ansible_facts_modified = stale_time
        stale_host.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
        job_created = now() - timedelta(minutes=5)
        start_fact_cache(inv.hosts.all(), artifacts_dir=artifacts_dir, timeout=0, inventory_id=inv.id)
        # Remove the fact file to simulate ansible's clear_facts
        os.remove(os.path.join(artifacts_dir, 'fact_cache', stale_host.name))
        finish_fact_cache(inv.hosts, artifacts_dir=artifacts_dir, inventory_id=inv.id, job_created=job_created)
        stale_host.refresh_from_db()
        assert stale_host.ansible_facts == {}, 'Stale facts should have been cleared when the fact file is missing ' 'and ansible_facts_modified predates job_created'
@pytest.mark.django_db
class TestConstructedInventoryFactCache:
    """finish_fact_cache with a constructed inventory queryset must target source hosts."""

    def test_facts_resolve_to_source_host(self, organization, artifacts_dir):
        """Facts must be written to the source host, not the constructed copy.

        Generated by Claude Opus 4.6 (claude-opus-4-6).
        """
        from django.db.models.functions import Cast

        input_inv = Inventory.objects.create(organization=organization, name='ci-input')
        source_host = input_inv.hosts.create(name='webserver')
        constructed_inv = Inventory.objects.create(organization=organization, name='ci-constructed', kind='constructed')
        constructed_inv.input_inventories.add(input_inv)
        constructed_host = Host.objects.create(inventory=constructed_inv, name='webserver', instance_id=str(source_host.id))
        # Build the same queryset that get_hosts_for_fact_cache uses
        id_field = Host._meta.get_field('id')
        source_qs = Host.objects.filter(id__in=constructed_inv.hosts.exclude(instance_id='').values_list(Cast('instance_id', output_field=id_field)))
        # Give the source host initial facts so start_fact_cache writes a file
        source_host.ansible_facts = {'role': 'web'}
        source_host.ansible_facts_modified = now()
        source_host.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
        start_fact_cache(source_qs, artifacts_dir=artifacts_dir, timeout=0, inventory_id=constructed_inv.id)
        # Simulate ansible writing updated facts, with a future mtime so the
        # change registers as newer than the reference file.
        fact_file = os.path.join(artifacts_dir, 'fact_cache', 'webserver')
        future = time.time() + 60
        with open(fact_file, 'w') as f:
            json.dump({'role': 'web', 'deployed': True}, f)
        os.utime(fact_file, (future, future))
        finish_fact_cache(source_qs, artifacts_dir=artifacts_dir, inventory_id=constructed_inv.id)
        source_host.refresh_from_db()
        constructed_host.refresh_from_db()
        assert source_host.ansible_facts == {'role': 'web', 'deployed': True}
        assert not constructed_host.ansible_facts, f'Facts were stored on the constructed host: {constructed_host.ansible_facts!r}'

View File

@@ -29,30 +29,3 @@ def test_cancel_flag_on_start(jt_linked, caplog):
job = Job.objects.get(id=job.id) job = Job.objects.get(id=job.id)
assert job.status == 'canceled' assert job.status == 'canceled'
@pytest.mark.django_db
def test_runjob_run_can_accept_waiting_status(jt_linked, mocker):
    """Test that RunJob.run() can accept a job in 'waiting' status and transition it to 'running'
    before the pre_run_hook is called"""
    job = jt_linked.create_unified_job()
    job.status = 'waiting'
    job.save()

    observed_status = None

    def capture_status(instance, private_data_dir):
        # Record the DB status at the moment pre_run_hook fires.
        nonlocal observed_status
        instance.refresh_from_db()
        observed_status = instance.status

    mock_pre_run = mocker.patch.object(RunJob, 'pre_run_hook', side_effect=capture_status)
    task = RunJob()
    try:
        task.run(job.id)
    except Exception:
        # The run is expected to blow up later (no real execution env); the
        # assertion below only cares about the status seen at pre_run_hook.
        pass
    mock_pre_run.assert_called_once()
    assert observed_status == 'running'

View File

@@ -9,7 +9,7 @@ from unittest import mock
import pytest import pytest
from awx.main.tasks.system import CleanupImagesAndFiles, execution_node_health_check, inspect_established_receptor_connections, clear_setting_cache from awx.main.tasks.system import CleanupImagesAndFiles, execution_node_health_check, inspect_established_receptor_connections, clear_setting_cache
from awx.main.management.commands.dispatcherd import Command from awx.main.management.commands.run_dispatcher import Command
from awx.main.models import Instance, Job, ReceptorAddress, InstanceLink from awx.main.models import Instance, Job, ReceptorAddress, InstanceLink

View File

@@ -74,64 +74,47 @@ GLqbpJyX2r3p/Rmo6mLY71SqpA==
@pytest.mark.django_db @pytest.mark.django_db
def test_default_cred_types(): def test_default_cred_types():
expected = [ assert sorted(CredentialType.defaults.keys()) == sorted(
'aim', [
'aws', 'aim',
'aws_secretsmanager_credential', 'aws',
'azure_kv', 'aws_secretsmanager_credential',
'azure_rm', 'azure_kv',
'bitbucket_dc_token', 'azure_rm',
'centrify_vault_kv', 'bitbucket_dc_token',
'conjur', 'centrify_vault_kv',
'controller', 'conjur',
'galaxy_api_token', 'controller',
'gce', 'galaxy_api_token',
'github_token', 'gce',
'github_app_lookup', 'github_token',
'gitlab_token', 'github_app_lookup',
'gpg_public_key', 'gitlab_token',
'hashivault_kv', 'gpg_public_key',
'hashivault_ssh', 'hashivault_kv',
'hcp_terraform', 'hashivault_ssh',
'insights', 'hcp_terraform',
'kubernetes_bearer_token', 'insights',
'net', 'kubernetes_bearer_token',
'openstack', 'net',
'registry', 'openstack',
'rhv', 'registry',
'satellite6', 'rhv',
'scm', 'satellite6',
'ssh', 'scm',
'terraform', 'ssh',
'thycotic_dsv', 'terraform',
'thycotic_tss', 'thycotic_dsv',
'vault', 'thycotic_tss',
'vmware', 'vault',
] 'vmware',
assert sorted(CredentialType.defaults.keys()) == sorted(expected) ]
assert 'hashivault-kv-oidc' not in CredentialType.defaults )
assert 'hashivault-ssh-oidc' not in CredentialType.defaults
for type_ in CredentialType.defaults.values(): for type_ in CredentialType.defaults.values():
assert type_().managed is True assert type_().managed is True
@pytest.mark.django_db
def test_default_cred_types_with_oidc_enabled():
    """The hashivault OIDC credential types register only when the feature flag is enabled."""
    from django.test import override_settings
    from awx.main.models.credential import load_credentials, ManagedCredentialType

    # Save the registry so this test cannot leak state into other tests.
    saved_registry = ManagedCredentialType.registry.copy()
    try:
        with override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=True):
            ManagedCredentialType.registry.clear()
            load_credentials()
            assert 'hashivault-kv-oidc' in CredentialType.defaults
            assert 'hashivault-ssh-oidc' in CredentialType.defaults
    finally:
        ManagedCredentialType.registry = saved_registry
@pytest.mark.django_db @pytest.mark.django_db
def test_credential_creation(organization_factory): def test_credential_creation(organization_factory):
org = organization_factory('test').organization org = organization_factory('test').organization

View File

@@ -5,7 +5,6 @@ import pytest
from awx.main.models import Job, WorkflowJob, Instance from awx.main.models import Job, WorkflowJob, Instance
from awx.main.dispatch import reaper from awx.main.dispatch import reaper
from awx.main.tasks import system
from dispatcherd.publish import task from dispatcherd.publish import task
''' '''
@@ -62,6 +61,11 @@ class TestJobReaper(object):
('running', '', '', None, False), # running, not assigned to the instance ('running', '', '', None, False), # running, not assigned to the instance
('running', 'awx', '', None, True), # running, has the instance as its execution_node ('running', 'awx', '', None, True), # running, has the instance as its execution_node
('running', '', 'awx', None, True), # running, has the instance as its controller_node ('running', '', 'awx', None, True), # running, has the instance as its controller_node
('waiting', '', '', None, False), # waiting, not assigned to the instance
('waiting', 'awx', '', None, False), # waiting, was edited less than a minute ago
('waiting', '', 'awx', None, False), # waiting, was edited less than a minute ago
('waiting', 'awx', '', yesterday, False), # waiting, managed by another node, ignore
('waiting', '', 'awx', yesterday, True), # waiting, assigned to the controller_node, stale
], ],
) )
def test_should_reap(self, status, fail, execution_node, controller_node, modified): def test_should_reap(self, status, fail, execution_node, controller_node, modified):
@@ -79,6 +83,7 @@ class TestJobReaper(object):
# (because .save() overwrites it to _now_) # (because .save() overwrites it to _now_)
Job.objects.filter(id=j.id).update(modified=modified) Job.objects.filter(id=j.id).update(modified=modified)
reaper.reap(i) reaper.reap(i)
reaper.reap_waiting(i)
job = Job.objects.first() job = Job.objects.first()
if fail: if fail:
assert job.status == 'failed' assert job.status == 'failed'
@@ -87,20 +92,6 @@ class TestJobReaper(object):
else: else:
assert job.status == status assert job.status == status
def test_waiting_job_sent_back_to_pending(self):
this_inst = Instance(hostname='awx')
this_inst.save()
lost_inst = Instance(hostname='lost', node_type=Instance.Types.EXECUTION, node_state=Instance.States.UNAVAILABLE)
lost_inst.save()
job = Job.objects.create(status='waiting', controller_node=lost_inst.hostname, execution_node='lost')
system._heartbeat_handle_lost_instances([lost_inst], this_inst)
job.refresh_from_db()
assert job.status == 'pending'
assert job.controller_node == ''
assert job.execution_node == ''
@pytest.mark.parametrize( @pytest.mark.parametrize(
'excluded_uuids, fail, started', 'excluded_uuids, fail, started',
[ [

View File

@@ -8,7 +8,6 @@ from awx.main.models import (
Instance, Instance,
Host, Host,
JobHostSummary, JobHostSummary,
Inventory,
InventoryUpdate, InventoryUpdate,
InventorySource, InventorySource,
Project, Project,
@@ -18,60 +17,14 @@ from awx.main.models import (
InstanceGroup, InstanceGroup,
Label, Label,
ExecutionEnvironment, ExecutionEnvironment,
Credential,
CredentialType,
CredentialInputSource,
Organization,
JobTemplate,
) )
from awx.main.tasks import jobs
from awx.main.tasks.system import cluster_node_heartbeat from awx.main.tasks.system import cluster_node_heartbeat
from awx.main.utils.db import bulk_update_sorted_by_id from awx.main.utils.db import bulk_update_sorted_by_id
from ansible_base.lib.testing.util import feature_flag_enabled, feature_flag_disabled
from django.db import OperationalError from django.db import OperationalError
from django.test.utils import override_settings from django.test.utils import override_settings
@pytest.fixture
def job_template_with_credentials():
    """
    Factory fixture that creates a job template with specified credentials.

    Usage:
        job = job_template_with_credentials(ssh_cred, vault_cred)
    """

    def _create_job_template(
        *credentials, org_name='test-org', project_name='test-project', inventory_name='test-inventory', jt_name='test-jt', playbook='test.yml'
    ):
        """
        Create a job template with the given credentials.

        Args:
            *credentials: Variable number of Credential objects to attach to the job template
            org_name: Name for the organization
            project_name: Name for the project
            inventory_name: Name for the inventory
            jt_name: Name for the job template
            playbook: Playbook filename

        Returns:
            Job instance created from the job template
        """
        organization = Organization.objects.create(name=org_name)
        project = Project.objects.create(name=project_name, organization=organization)
        inventory = Inventory.objects.create(name=inventory_name, organization=organization)
        template = JobTemplate.objects.create(name=jt_name, project=project, inventory=inventory, playbook=playbook)
        if credentials:
            template.credentials.add(*credentials)
        return template.create_unified_job()

    return _create_job_template
@pytest.mark.django_db @pytest.mark.django_db
def test_orphan_unified_job_creation(instance, inventory): def test_orphan_unified_job_creation(instance, inventory):
job = Job.objects.create(job_template=None, inventory=inventory, name='hi world') job = Job.objects.create(job_template=None, inventory=inventory, name='hi world')
@@ -309,442 +262,3 @@ class TestLaunchConfig:
assert config.execution_environment assert config.execution_environment
# We just write the PK instead of trying to assign an item, that happens on the save # We just write the PK instead of trying to assign an item, that happens on the save
assert config.execution_environment_id == ee.id assert config.execution_environment_id == ee.id
@pytest.mark.django_db
def test_base_task_credentials_property(job_template_with_credentials):
    """Test that _credentials property caches credentials and doesn't re-query."""
    task = jobs.RunJob()
    # Create real credential types and credentials.
    ssh_type = CredentialType.defaults['ssh']()
    ssh_type.save()
    vault_type = CredentialType.defaults['vault']()
    vault_type.save()
    ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
    vault_cred = Credential.objects.create(credential_type=vault_type, name='vault-cred')
    # Attach a job carrying both credentials.
    task.instance = job_template_with_credentials(ssh_cred, vault_cred)
    # First access builds the list.
    first = task._credentials
    assert isinstance(first, list)
    assert len(first) == 2
    # Second access must hand back the very same list object (cached, no re-query).
    assert task._credentials is first
@pytest.mark.django_db
def test_run_job_machine_credential(job_template_with_credentials):
    """Test _machine_credential returns ssh credential from cache."""
    task = jobs.RunJob()
    ssh_type = CredentialType.defaults['ssh']()
    ssh_type.save()
    vault_type = CredentialType.defaults['vault']()
    vault_type.save()
    ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
    vault_cred = Credential.objects.create(credential_type=vault_type, name='vault-cred')
    task.instance = job_template_with_credentials(ssh_cred, vault_cred)
    # Pre-populate the credential cache so _machine_credential reads from it.
    task._credentials = [ssh_cred, vault_cred]
    machine_cred = task._machine_credential
    assert machine_cred == ssh_cred
    assert machine_cred.credential_type.kind == 'ssh'
@pytest.mark.django_db
def test_run_job_machine_credential_none(job_template_with_credentials):
    """_machine_credential yields None when the cache holds no ssh credential."""
    task = jobs.RunJob()
    # Only a vault credential exists — no machine (ssh) credential at all.
    vault_type = CredentialType.defaults['vault']()
    vault_type.save()
    only_vault = Credential.objects.create(credential_type=vault_type, name='vault-cred')
    task.instance = job_template_with_credentials(only_vault)
    # Seed the cache directly with the single non-ssh credential.
    task._credentials = [only_vault]
    assert task._machine_credential is None
@pytest.mark.django_db
def test_run_job_vault_credentials(job_template_with_credentials):
    """_vault_credentials returns every vault credential — and only those — from the cache."""
    task = jobs.RunJob()
    vault_type = CredentialType.defaults['vault']()
    vault_type.save()
    ssh_type = CredentialType.defaults['ssh']()
    ssh_type.save()
    # Two vault credentials plus one ssh credential mixed in between.
    vaults = [Credential.objects.create(credential_type=vault_type, name=f'vault-{i}') for i in (1, 2)]
    machine_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
    task.instance = job_template_with_credentials(vaults[0], machine_cred, vaults[1])
    # Seed the cache directly; the property should filter it, not hit the DB.
    task._credentials = [vaults[0], machine_cred, vaults[1]]
    selected = task._vault_credentials
    assert len(selected) == 2
    assert machine_cred not in selected
    for vault_cred in vaults:
        assert vault_cred in selected
@pytest.mark.django_db
def test_run_job_network_credentials(job_template_with_credentials):
    """_network_credentials filters the cached list down to net-kind credentials."""
    task = jobs.RunJob()
    net_type = CredentialType.defaults['net']()
    net_type.save()
    ssh_type = CredentialType.defaults['ssh']()
    ssh_type.save()
    network_cred = Credential.objects.create(credential_type=net_type, name='net-cred')
    machine_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
    task.instance = job_template_with_credentials(network_cred, machine_cred)
    # Seed the cache directly so the property reads from it.
    task._credentials = [network_cred, machine_cred]
    found = task._network_credentials
    # Exactly the single net credential survives the filter.
    assert len(found) == 1
    assert found[0] == network_cred
@pytest.mark.django_db
def test_run_job_cloud_credentials(job_template_with_credentials):
    """_cloud_credentials filters the cached list down to cloud-kind credentials."""
    task = jobs.RunJob()
    aws_type = CredentialType.defaults['aws']()
    aws_type.save()
    ssh_type = CredentialType.defaults['ssh']()
    ssh_type.save()
    cloud_cred = Credential.objects.create(credential_type=aws_type, name='aws-cred')
    machine_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
    task.instance = job_template_with_credentials(cloud_cred, machine_cred)
    # Seed the cache directly so the property reads from it.
    task._credentials = [cloud_cred, machine_cred]
    found = task._cloud_credentials
    # Only the aws credential qualifies as a cloud credential here.
    assert len(found) == 1
    assert found[0] == cloud_cred
@pytest.mark.django_db
@override_settings(RESOURCE_SERVER={'URL': 'https://gateway.example.com', 'SECRET_KEY': 'test-secret-key', 'VALIDATE_HTTPS': False})
def test_populate_workload_identity_tokens_with_flag_enabled(job_template_with_credentials, mocker):
    """Test populate_workload_identity_tokens sets context when flag is enabled.

    Builds a target credential whose 'password' field is sourced from a
    workload-identity (OIDC) lookup credential, then verifies that a single
    token request goes to the Gateway endpoint and the returned JWT is stored
    on the target credential's context, keyed by the input source PK.
    """
    with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
        task = jobs.RunJob()
        # Create credential types
        ssh_type = CredentialType.defaults['ssh']()
        ssh_type.save()
        # Create a workload identity credential type: a jwt_aud input plus an
        # internal secret field where the runtime token is injected.
        hashivault_type = CredentialType(
            name='HashiCorp Vault Secret Lookup (OIDC)',
            kind='cloud',
            managed=False,
            inputs={
                'fields': [
                    {'id': 'jwt_aud', 'type': 'string', 'label': 'JWT Audience'},
                    {'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
                ]
            },
        )
        hashivault_type.save()
        # Create credentials
        ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
        source_cred = Credential.objects.create(credential_type=hashivault_type, name='vault-source', inputs={'jwt_aud': 'https://vault.example.com'})
        target_cred = Credential.objects.create(credential_type=ssh_type, name='target-cred', inputs={'username': 'testuser'})
        # Create input source linking source credential to target credential
        input_source = CredentialInputSource.objects.create(
            target_credential=target_cred, source_credential=source_cred, input_field_name='password', metadata={'path': 'secret/data/password'}
        )
        # Create a job using fixture
        job = job_template_with_credentials(target_cred, ssh_cred)
        task.instance = job
        # Override cached_property so the loop uses these exact Python objects
        task._credentials = [target_cred, ssh_cred]
        # Mock only the HTTP response from the Gateway workload identity endpoint
        mock_response = mocker.Mock(status_code=200)
        mock_response.json.return_value = {'jwt': 'eyJ.test.jwt'}
        mock_request = mocker.patch('requests.request', return_value=mock_response, autospec=True)
        task.populate_workload_identity_tokens()
        # Verify the HTTP call was made to the correct endpoint
        mock_request.assert_called_once()
        call_kwargs = mock_request.call_args.kwargs
        assert call_kwargs['method'] == 'POST'
        assert '/api/gateway/v1/workload_identity_tokens' in call_kwargs['url']
        # Verify context was set on the credential, keyed by input source PK
        assert input_source.pk in target_cred.context
        assert target_cred.context[input_source.pk]['workload_identity_token'] == 'eyJ.test.jwt'
@pytest.mark.django_db
@override_settings(RESOURCE_SERVER={'URL': 'https://gateway.example.com', 'SECRET_KEY': 'test-secret-key', 'VALIDATE_HTTPS': False})
def test_populate_workload_identity_tokens_passes_workload_ttl_from_job_timeout(job_template_with_credentials, mocker):
    """Test populate_workload_identity_tokens passes workload_ttl_seconds from get_instance_timeout to the client."""
    with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
        task = jobs.RunJob()
        ssh_type = CredentialType.defaults['ssh']()
        ssh_type.save()
        # Workload identity credential type: a jwt_aud input plus an internal
        # secret field where the runtime token is injected.
        hashivault_type = CredentialType(
            name='HashiCorp Vault Secret Lookup (OIDC)',
            kind='cloud',
            managed=False,
            inputs={
                'fields': [
                    {'id': 'jwt_aud', 'type': 'string', 'label': 'JWT Audience'},
                    {'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
                ]
            },
        )
        hashivault_type.save()
        ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
        source_cred = Credential.objects.create(credential_type=hashivault_type, name='vault-source', inputs={'jwt_aud': 'https://vault.example.com'})
        target_cred = Credential.objects.create(credential_type=ssh_type, name='target-cred', inputs={'username': 'testuser'})
        # Input source linking source credential to target credential triggers the token request.
        CredentialInputSource.objects.create(
            target_credential=target_cred, source_credential=source_cred, input_field_name='password', metadata={'path': 'secret/data/password'}
        )
        job = job_template_with_credentials(target_cred, ssh_cred)
        # Job timeout is expected to flow through to the requested token TTL.
        job.timeout = 3600
        job.save()
        task.instance = job
        # Override cached_property so the loop uses these exact Python objects
        task._credentials = [target_cred, ssh_cred]
        # Mock only the HTTP response from the Gateway workload identity endpoint
        mock_response = mocker.Mock(status_code=200)
        mock_response.json.return_value = {'jwt': 'eyJ.test.jwt'}
        mock_request = mocker.patch('requests.request', return_value=mock_response, autospec=True)
        task.populate_workload_identity_tokens()
        call_kwargs = mock_request.call_args.kwargs
        assert call_kwargs['method'] == 'POST'
        json_body = call_kwargs.get('json', {})
        # 3600 matches the job.timeout set above.
        assert json_body.get('workload_ttl_seconds') == 3600
@pytest.mark.django_db
def test_populate_workload_identity_tokens_with_flag_disabled(job_template_with_credentials):
    """Test populate_workload_identity_tokens sets error status when flag is disabled.

    A credential sourced from a workload identity lookup while the feature
    flag is off should put the job into 'error' with an explanation that
    names both the flag and the offending source credential.
    """
    with feature_flag_disabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
        task = jobs.RunJob()
        # Create credential types
        ssh_type = CredentialType.defaults['ssh']()
        ssh_type.save()
        # Create a workload identity credential type
        hashivault_type = CredentialType(
            name='HashiCorp Vault Secret Lookup (OIDC)',
            kind='cloud',
            managed=False,
            inputs={
                'fields': [
                    {'id': 'jwt_aud', 'type': 'string', 'label': 'JWT Audience'},
                    {'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
                ]
            },
        )
        hashivault_type.save()
        # Create credentials
        source_cred = Credential.objects.create(credential_type=hashivault_type, name='vault-source')
        target_cred = Credential.objects.create(credential_type=ssh_type, name='target-cred', inputs={'username': 'testuser'})
        # Create input source linking source credential to target credential
        # Note: Creates the relationship that will trigger the feature flag check
        CredentialInputSource.objects.create(
            target_credential=target_cred, source_credential=source_cred, input_field_name='password', metadata={'path': 'secret/data/password'}
        )
        # Create a job using fixture
        job = job_template_with_credentials(target_cred)
        task.instance = job
        # Set cached credentials
        task._credentials = [target_cred]
        task.populate_workload_identity_tokens()
        # Verify job status was set to error
        job.refresh_from_db()
        assert job.status == 'error'
        # The explanation must call out the disabled flag and the source credential name.
        assert 'FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED' in job.job_explanation
        assert 'vault-source' in job.job_explanation
@pytest.mark.django_db
@override_settings(RESOURCE_SERVER={'URL': 'https://gateway.example.com', 'SECRET_KEY': 'test-secret-key', 'VALIDATE_HTTPS': False})
def test_populate_workload_identity_tokens_multiple_input_sources_per_credential(job_template_with_credentials, mocker):
    """Test that a single credential with two input sources from different workload identity
    credential types gets a separate JWT token for each input source, keyed by input source PK."""
    with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
        task = jobs.RunJob()
        # Create credential types
        ssh_type = CredentialType.defaults['ssh']()
        ssh_type.save()
        # Create two different workload identity credential types
        hashivault_kv_type = CredentialType(
            name='HashiCorp Vault Secret Lookup (OIDC)',
            kind='cloud',
            managed=False,
            inputs={
                'fields': [
                    {'id': 'jwt_aud', 'type': 'string', 'label': 'JWT Audience'},
                    {'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
                ]
            },
        )
        hashivault_kv_type.save()
        hashivault_ssh_type = CredentialType(
            name='HashiCorp Vault Signed SSH (OIDC)',
            kind='cloud',
            managed=False,
            inputs={
                'fields': [
                    {'id': 'jwt_aud', 'type': 'string', 'label': 'JWT Audience'},
                    {'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
                ]
            },
        )
        hashivault_ssh_type.save()
        # Create source credentials with different audiences
        source_cred_kv = Credential.objects.create(
            credential_type=hashivault_kv_type, name='vault-kv-source', inputs={'jwt_aud': 'https://vault-kv.example.com'}
        )
        source_cred_ssh = Credential.objects.create(
            credential_type=hashivault_ssh_type, name='vault-ssh-source', inputs={'jwt_aud': 'https://vault-ssh.example.com'}
        )
        # Create target credential that uses both sources for different fields
        target_cred = Credential.objects.create(credential_type=ssh_type, name='target-cred', inputs={'username': 'testuser'})
        # Create two input sources on the same target credential, each for a different field
        input_source_password = CredentialInputSource.objects.create(
            target_credential=target_cred, source_credential=source_cred_kv, input_field_name='password', metadata={'path': 'secret/data/password'}
        )
        input_source_ssh_key = CredentialInputSource.objects.create(
            target_credential=target_cred, source_credential=source_cred_ssh, input_field_name='ssh_key_data', metadata={'path': 'secret/data/ssh_key'}
        )
        # Create a job using fixture
        job = job_template_with_credentials(target_cred)
        task.instance = job
        # Override cached_property so the loop uses this exact Python object
        task._credentials = [target_cred]
        # Mock HTTP responses - return different JWTs for each call.
        # NOTE(review): side_effect order assumes the kv input source is
        # processed before the ssh one; the audience assertions below are
        # order-independent, but the per-token assertions rely on it.
        response_kv = mocker.Mock(status_code=200)
        response_kv.json.return_value = {'jwt': 'eyJ.kv.jwt'}
        response_ssh = mocker.Mock(status_code=200)
        response_ssh.json.return_value = {'jwt': 'eyJ.ssh.jwt'}
        mock_request = mocker.patch('requests.request', side_effect=[response_kv, response_ssh], autospec=True)
        task.populate_workload_identity_tokens()
        # Verify two separate HTTP calls were made (one per input source)
        assert mock_request.call_count == 2
        # Verify each call used the correct audience from its source credential
        audiences_requested = {call.kwargs.get('json', {}).get('audience', '') for call in mock_request.call_args_list}
        assert 'https://vault-kv.example.com' in audiences_requested
        assert 'https://vault-ssh.example.com' in audiences_requested
        # Verify context on the target credential has both tokens, keyed by input source PK
        assert input_source_password.pk in target_cred.context
        assert input_source_ssh_key.pk in target_cred.context
        assert target_cred.context[input_source_password.pk]['workload_identity_token'] == 'eyJ.kv.jwt'
        assert target_cred.context[input_source_ssh_key.pk]['workload_identity_token'] == 'eyJ.ssh.jwt'
@pytest.mark.django_db
def test_populate_workload_identity_tokens_without_workload_identity_credentials(job_template_with_credentials, mocker):
    """Test populate_workload_identity_tokens does nothing when no workload identity credentials."""
    with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
        task = jobs.RunJob()
        # Create only standard credentials (no workload identity)
        ssh_type = CredentialType.defaults['ssh']()
        ssh_type.save()
        vault_type = CredentialType.defaults['vault']()
        vault_type.save()
        ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
        vault_cred = Credential.objects.create(credential_type=vault_type, name='vault-cred')
        # Create a job using fixture
        job = job_template_with_credentials(ssh_cred, vault_cred)
        task.instance = job
        # Set cached credentials
        task._credentials = [ssh_cred, vault_cred]
        # Patch the claims helper so no real claim-building runs even if reached.
        mocker.patch('awx.main.tasks.jobs.populate_claims_for_workload', return_value={'job_id': 123}, autospec=True)
        task.populate_workload_identity_tokens()
        # Verify no context was set — presumably `context` lazily creates an
        # empty dict backed by `_context`, hence the two-part check. TODO confirm.
        assert not hasattr(ssh_cred, '_context') or ssh_cred.context == {}
        assert not hasattr(vault_cred, '_context') or vault_cred.context == {}

View File

@@ -173,54 +173,3 @@ class TestMigrationSmoke:
assert Role.objects.filter( assert Role.objects.filter(
singleton_name='system_administrator', role_field='system_administrator' singleton_name='system_administrator', role_field='system_administrator'
).exists(), "expected to find a system_administrator singleton role" ).exists(), "expected to find a system_administrator singleton role"
@pytest.mark.django_db
class TestGithubAppBug:
    """
    Tests that `awx-manage createsuperuser` runs successfully after
    the `github_app` CredentialType kind is updated to `github_app_lookup`
    via the migration.
    """

    def test_after_github_app_kind_migration(self, migrator):
        """
        Verifies that `createsuperuser` does not raise a KeyError
        after the 0204_squashed_deletions migration (which includes
        the `update_github_app_kind` logic) is applied.
        """
        # 1. Apply migrations up to the point *before* the 0204_squashed_deletions migration.
        #    This simulates the state where the problematic CredentialType might exist.
        #    We use 0203_remove_team_of_teams as the direct predecessor.
        old_state = migrator.apply_tested_migration(('main', '0203_remove_team_of_teams'))
        # Get the CredentialType model from the historical state.
        CredentialType = old_state.apps.get_model('main', 'CredentialType')
        # Create a CredentialType with the old, problematic 'namespace' value
        CredentialType.objects.create(
            name='Legacy GitHub App Credential',
            kind='external',
            namespace='github_app',  # The namespace that causes the KeyError in the registry lookup
            managed=True,
            created=now(),
            modified=now(),
        )
        # Apply the migration that includes the fix (0204_squashed_deletions).
        new_state = migrator.apply_tested_migration(('main', '0204_squashed_deletions'))
        # Verify that the CredentialType with the old 'kind' no longer exists
        # and the 'kind' has been updated to the new value.
        CredentialType = new_state.apps.get_model('main', 'CredentialType')  # Get CredentialType model from the new state
        # Assertion 1: The CredentialType with the old 'github_app' namespace should no longer exist.
        assert not CredentialType.objects.filter(
            namespace='github_app'
        ).exists(), "CredentialType with old 'github_app' kind should no longer exist after migration."
        # Assertion 2: The CredentialType should now exist with the new 'github_app_lookup' namespace
        # and retain its original name.
        assert CredentialType.objects.filter(
            namespace='github_app_lookup', name='Legacy GitHub App Credential'
        ).exists(), "CredentialType should be updated to 'github_app_lookup' and retain its name."

View File

@@ -18,14 +18,13 @@ from awx.main.tests.functional.conftest import * # noqa
from awx.main.tests.conftest import load_all_credentials # noqa: F401; pylint: disable=unused-import from awx.main.tests.conftest import load_all_credentials # noqa: F401; pylint: disable=unused-import
from awx.main.tests import data from awx.main.tests import data
from awx.main.models import Project, JobTemplate, Organization, Inventory, WorkflowJob, UnifiedJob from awx.main.models import Project, JobTemplate, Organization, Inventory
from awx.main.tasks.system import clear_setting_cache from awx.main.tasks.system import clear_setting_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
PROJ_DATA = os.path.join(os.path.dirname(data.__file__), 'projects') PROJ_DATA = os.path.join(os.path.dirname(data.__file__), 'projects')
COLL_DATA = os.path.join(os.path.dirname(data.__file__), 'collections')
def _copy_folders(source_path, dest_path, clear=False): def _copy_folders(source_path, dest_path, clear=False):
@@ -57,7 +56,6 @@ def live_tmp_folder():
shutil.rmtree(path) shutil.rmtree(path)
os.mkdir(path) os.mkdir(path)
_copy_folders(PROJ_DATA, path) _copy_folders(PROJ_DATA, path)
_copy_folders(COLL_DATA, path)
for dirname in os.listdir(path): for dirname in os.listdir(path):
source_dir = os.path.join(path, dirname) source_dir = os.path.join(path, dirname)
subprocess.run(GIT_COMMANDS, cwd=source_dir, shell=True) subprocess.run(GIT_COMMANDS, cwd=source_dir, shell=True)
@@ -71,7 +69,7 @@ def live_tmp_folder():
settings._awx_conf_memoizedcache.clear() settings._awx_conf_memoizedcache.clear()
# cache is cleared in test environment, but need to clear in test environment # cache is cleared in test environment, but need to clear in test environment
clear_setting_cache.delay(['AWX_ISOLATION_SHOW_PATHS']) clear_setting_cache.delay(['AWX_ISOLATION_SHOW_PATHS'])
time.sleep(5.0) # for _awx_conf_memoizedcache to expire on all workers time.sleep(0.2) # allow task to finish, we have no real metric to know
else: else:
logger.info(f'Believed that {path} is already in settings.AWX_ISOLATION_SHOW_PATHS: {settings.AWX_ISOLATION_SHOW_PATHS}') logger.info(f'Believed that {path} is already in settings.AWX_ISOLATION_SHOW_PATHS: {settings.AWX_ISOLATION_SHOW_PATHS}')
return path return path
@@ -102,21 +100,6 @@ def wait_for_events(uj, timeout=2):
def unified_job_stdout(uj): def unified_job_stdout(uj):
if type(uj) is UnifiedJob:
uj = uj.get_real_instance()
if isinstance(uj, WorkflowJob):
outputs = []
for node in uj.workflow_job_nodes.all().select_related('job').order_by('id'):
if node.job is None:
continue
outputs.append(
'workflow node {node_id} job {job_id} output:\n{output}'.format(
node_id=node.id,
job_id=node.job.id,
output=unified_job_stdout(node.job),
)
)
return '\n'.join(outputs)
wait_for_events(uj) wait_for_events(uj)
return '\n'.join([event.stdout for event in uj.get_event_queryset().order_by('created')]) return '\n'.join([event.stdout for event in uj.get_event_queryset().order_by('created')])

View File

@@ -1,351 +0,0 @@
"""Tests for concurrent fact caching with --limit.
Reproduces bugs where concurrent jobs targeting different hosts via --limit
incorrectly modify (clear or revert) facts for hosts outside their limit.
Customer report: concurrent jobs on the same job template with different limits
cause facts set by an earlier-finishing job to be rolled back when the
later-finishing job completes.
See: https://github.com/jritter/concurrent-aap-fact-caching
Generated by Claude Opus 4.6 (claude-opus-4-6).
"""
import logging
import pytest
from django.utils.timezone import now
from awx.api.versioning import reverse
from awx.main.models import Inventory, JobTemplate
from awx.main.tests.live.tests.conftest import wait_for_job, wait_to_leave_status
logger = logging.getLogger(__name__)
@pytest.fixture
def concurrent_facts_inventory(default_org):
    """Inventory with two hosts for concurrent fact cache testing."""
    inv_name = 'test_concurrent_fact_cache'
    # Drop any leftover inventory from a previous run before rebuilding.
    Inventory.objects.filter(organization=default_org, name=inv_name).delete()
    inventory = Inventory.objects.create(organization=default_org, name=inv_name)
    for suffix in range(2):
        inventory.hosts.create(name=f'cc_host_{suffix}')
    return inventory
@pytest.fixture
def concurrent_facts_jt(concurrent_facts_inventory, live_tmp_folder, post, admin, project_factory):
    """Job template configured for concurrent fact-cached runs.

    Builds a project from the local 'facts' repo, waits for its initial sync,
    then creates the JT through the API with fact caching and simultaneous
    execution enabled.
    """
    proj = project_factory(scm_url=f'file://{live_tmp_folder}/facts')
    if proj.current_job:
        wait_for_job(proj.current_job)
    assert 'gather_slow.yml' in proj.playbooks, f'gather_slow.yml not in {proj.playbooks}'
    jt_name = 'test_concurrent_fact_cache JT'
    # Delete ALL leftover job templates with this name. The previous
    # .first()-based cleanup removed at most one, so duplicates from earlier
    # failed runs could survive and break the expect=201 POST below. This
    # also matches the queryset-delete cleanup style used by the inventory fixture.
    JobTemplate.objects.filter(name=jt_name).delete()
    result = post(
        reverse('api:job_template_list'),
        {
            'name': jt_name,
            'project': proj.id,
            'playbook': 'gather_slow.yml',
            'inventory': concurrent_facts_inventory.id,
            'use_fact_cache': True,
            'allow_simultaneous': True,
        },
        admin,
        expect=201,
    )
    return JobTemplate.objects.get(id=result.data['id'])
def test_concurrent_limit_does_not_clear_facts(concurrent_facts_inventory, concurrent_facts_jt):
    """Concurrent jobs with different --limit must not clear each other's facts.
    Generated by Claude Opus 4.6 (claude-opus-4-6).
    Scenario:
    - Inventory has cc_host_0 and cc_host_1, neither has prior facts
    - Job A runs gather_slow.yml with limit=cc_host_0
    - While Job A is still running (sleeping), Job B launches with limit=cc_host_1
    - Both jobs set cacheable facts, but only for their respective limited host
    - After both complete, BOTH hosts should have populated facts
    The bug: get_hosts_for_fact_cache() returns ALL inventory hosts regardless
    of --limit. start_fact_cache records them all in hosts_cached but writes
    no fact files (no prior facts). When the later-finishing job runs
    finish_fact_cache, it sees a missing fact file for the other job's host
    and clears that host's facts.
    """
    inv = concurrent_facts_inventory
    jt = concurrent_facts_jt
    # Launch Job A targeting cc_host_0
    job_a = jt.create_unified_job()
    job_a.limit = 'cc_host_0'
    job_a.save(update_fields=['limit'])
    job_a.signal_start()
    # Wait for Job A to reach running (it will sleep inside the playbook)
    wait_to_leave_status(job_a, 'pending')
    wait_to_leave_status(job_a, 'waiting')
    logger.info(f'Job A (id={job_a.id}) is now running with limit=cc_host_0')
    # Launch Job B targeting cc_host_1 while Job A is still running
    job_b = jt.create_unified_job()
    job_b.limit = 'cc_host_1'
    job_b.save(update_fields=['limit'])
    job_b.signal_start()
    # Verify that Job A is still running when Job B starts,
    # otherwise the overlap that triggers the bug did not happen.
    wait_to_leave_status(job_b, 'pending')
    wait_to_leave_status(job_b, 'waiting')
    job_a.refresh_from_db()
    if job_a.status != 'running':
        # Timing-dependent: skip rather than fail when the race did not materialize.
        pytest.skip('Job A finished before Job B started running; overlap did not occur')
    logger.info(f'Job B (id={job_b.id}) is now running with limit=cc_host_1 (concurrent with Job A)')
    # Wait for both to complete
    wait_for_job(job_a)
    wait_for_job(job_b)
    # Verify facts survived concurrent execution
    host_0 = inv.hosts.get(name='cc_host_0')
    host_1 = inv.hosts.get(name='cc_host_1')
    # sanity: the limits must not have been mutated by the run
    job_a.refresh_from_db()
    job_b.refresh_from_db()
    assert job_a.limit == "cc_host_0"
    assert job_b.limit == "cc_host_1"
    # Both hosts must end up with the fact the playbook sets (foo=bar).
    discovered_foos = [host_0.ansible_facts.get('foo'), host_1.ansible_facts.get('foo')]
    assert discovered_foos == ['bar'] * 2, f'Unexpected facts on cc_host_0 or _1: {discovered_foos} after job a,b {job_a.id}, {job_b.id}'
def test_concurrent_limit_does_not_revert_facts(live_tmp_folder, run_job_from_playbook, concurrent_facts_inventory):
    """Concurrent jobs must not revert facts that a prior concurrent job just set.
    Generated by Claude Opus 4.6 (claude-opus-4-6).
    Scenario:
    - First, populate both hosts with initial facts (foo=bar) via a
      non-concurrent gather run
    - Then run two concurrent jobs with different limits, each setting
      a new value (foo=bar_v2 via extra_vars)
    - After both complete, BOTH hosts should have foo=bar_v2
    The bug: start_fact_cache writes the OLD facts (foo=bar) into each job's
    artifact dir for ALL hosts. If ansible's cache plugin rewrites a non-limited
    host's fact file with the stale content (updating the mtime), finish_fact_cache
    treats it as a legitimate update and overwrites the DB with old values.
    """
    # --- Seed both hosts with initial facts via a non-concurrent run ---
    inv = concurrent_facts_inventory
    scm_url = f'file://{live_tmp_folder}/facts'
    res = run_job_from_playbook(
        'seed_facts_for_revert_test',
        'gather_slow.yml',
        scm_url=scm_url,
        jt_params={'use_fact_cache': True, 'allow_simultaneous': True, 'inventory': inv.id},
    )
    for host in inv.hosts.all():
        assert host.ansible_facts.get('foo') == 'bar', f'Seed run failed to set facts on {host.name}: {host.ansible_facts}'
    job = res['job']
    wait_for_job(job)
    # sanity, jobs should be set up to both have facts with just bar
    host_0 = inv.hosts.get(name='cc_host_0')
    host_1 = inv.hosts.get(name='cc_host_1')
    discovered_foos = [host_0.ansible_facts.get('foo'), host_1.ansible_facts.get('foo')]
    assert discovered_foos == ['bar'] * 2, f'Facts did not get expected initial values: {discovered_foos}'
    jt = job.job_template
    assert jt.allow_simultaneous is True
    assert jt.use_fact_cache is True
    # Sanity assertion, sometimes this would give problems from the Django rel cache
    assert jt.project
    # --- Run two concurrent jobs that write a new value ---
    # Update the JT to pass extra_vars that change the fact value
    jt.extra_vars = '{"extra_value": "_v2"}'
    jt.save(update_fields=['extra_vars'])
    job_a = jt.create_unified_job()
    job_a.limit = 'cc_host_0'
    job_a.save(update_fields=['limit'])
    job_a.signal_start()
    # Wait until Job A is actually executing before launching Job B.
    wait_to_leave_status(job_a, 'pending')
    wait_to_leave_status(job_a, 'waiting')
    job_b = jt.create_unified_job()
    job_b.limit = 'cc_host_1'
    job_b.save(update_fields=['limit'])
    job_b.signal_start()
    wait_to_leave_status(job_b, 'pending')
    wait_to_leave_status(job_b, 'waiting')
    job_a.refresh_from_db()
    if job_a.status != 'running':
        # Timing-dependent: skip rather than fail when the overlap did not materialize.
        pytest.skip('Job A finished before Job B started running; overlap did not occur')
    wait_for_job(job_a)
    wait_for_job(job_b)
    host_0 = inv.hosts.get(name='cc_host_0')
    host_1 = inv.hosts.get(name='cc_host_1')
    # Both hosts should have the UPDATED value, not the old seed value
    discovered_foos = [host_0.ansible_facts.get('foo'), host_1.ansible_facts.get('foo')]
    assert discovered_foos == ['bar_v2'] * 2, f'Facts were reverted to stale values by concurrent job cc_host_0 or cc_host_1: {discovered_foos}'
def test_fact_cache_scoped_to_inventory(live_tmp_folder, default_org, run_job_from_playbook):
    """finish_fact_cache must not modify hosts in other inventories.
    Generated by Claude Opus 4.6 (claude-opus-4-6).
    Bug: finish_fact_cache queries Host.objects.filter(name__in=host_names)
    without an inventory_id filter, so hosts with the same name in different
    inventories get their facts cross-contaminated.
    """
    shared_name = 'scope_shared_host'
    # Prepare for test by deleting junk from last run. Queryset .delete()
    # removes every leftover copy at once (the previous .first()-based
    # cleanup deleted at most one), matching the cleanup style used by
    # the constructed-inventory test in this module.
    for inv_name in ('test_fact_scope_inv1', 'test_fact_scope_inv2'):
        Inventory.objects.filter(name=inv_name).delete()
    inv1 = Inventory.objects.create(organization=default_org, name='test_fact_scope_inv1')
    inv1.hosts.create(name=shared_name)
    inv2 = Inventory.objects.create(organization=default_org, name='test_fact_scope_inv2')
    host2 = inv2.hosts.create(name=shared_name)
    # Give inv2's host distinct facts that should not be touched
    original_facts = {'source': 'inventory_2', 'untouched': True}
    host2.ansible_facts = original_facts
    host2.ansible_facts_modified = now()
    host2.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
    # Run a fact-gathering job against inv1 only
    run_job_from_playbook(
        'test_fact_scope',
        'gather.yml',
        scm_url=f'file://{live_tmp_folder}/facts',
        jt_params={'use_fact_cache': True, 'inventory': inv1.id},
    )
    # inv1's host should have facts
    host1 = inv1.hosts.get(name=shared_name)
    assert host1.ansible_facts, f'inv1 host should have facts after gather: {host1.ansible_facts}'
    # inv2's host must NOT have been touched
    host2.refresh_from_db()
    assert host2.ansible_facts == original_facts, (
        f'Host in a different inventory was modified by a fact cache operation '
        f'on another inventory sharing the same hostname. '
        f'Expected {original_facts!r}, got {host2.ansible_facts!r}'
    )
def test_constructed_inventory_facts_saved_to_source_host(live_tmp_folder, default_org, run_job_from_playbook):
    """Facts from a constructed inventory job must be saved to the source host.

    Generated by Claude Opus 4.6 (claude-opus-4-6).

    Constructed inventories contain hosts that are references (via instance_id)
    to 'real' hosts in input inventories. start_fact_cache correctly resolves
    source hosts via get_hosts_for_fact_cache(), but finish_fact_cache must also
    write facts back to the source hosts, not the constructed inventory's copies.

    Scenario:
    - Two input inventories each have a host named 'ci_shared_host'
    - A constructed inventory uses both as inputs
    - The inventory sync picks one source host (via instance_id) for the
      constructed host — which one depends on input processing order
    - Both source hosts start with distinct pre-existing facts
    - A fact-gathering job runs against the constructed inventory
    - After completion, the targeted source host should have the job's facts
    - The OTHER source host must retain its original facts untouched
    - The constructed host itself must NOT have facts stored on it
      (constructed hosts are transient — recreated on each inventory sync)
    """
    shared_name = 'ci_shared_host'
    # Cleanup from prior runs
    for inv_name in ('test_ci_facts_input1', 'test_ci_facts_input2', 'test_ci_facts_constructed'):
        Inventory.objects.filter(name=inv_name).delete()
    # --- Create two input inventories, each with an identically-named host ---
    inv1 = Inventory.objects.create(organization=default_org, name='test_ci_facts_input1')
    source_host1 = inv1.hosts.create(name=shared_name)
    inv2 = Inventory.objects.create(organization=default_org, name='test_ci_facts_input2')
    source_host2 = inv2.hosts.create(name=shared_name)
    # Give both hosts distinct pre-existing facts so we can detect cross-contamination
    host1_original_facts = {'source': 'inventory_1'}
    source_host1.ansible_facts = host1_original_facts
    source_host1.ansible_facts_modified = now()
    source_host1.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
    host2_original_facts = {'source': 'inventory_2'}
    source_host2.ansible_facts = host2_original_facts
    source_host2.ansible_facts_modified = now()
    source_host2.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
    # Lookup tables keyed by host pk, so the assertions below can address the
    # "target" vs "other" host without knowing which one the sync picked.
    source_hosts_by_id = {source_host1.id: source_host1, source_host2.id: source_host2}
    original_facts_by_id = {source_host1.id: host1_original_facts, source_host2.id: host2_original_facts}
    # --- Create constructed inventory (sync will create hosts from inputs) ---
    constructed_inv = Inventory.objects.create(
        organization=default_org,
        name='test_ci_facts_constructed',
        kind='constructed',
    )
    constructed_inv.input_inventories.add(inv1)
    constructed_inv.input_inventories.add(inv2)
    # --- Run a fact-gathering job against the constructed inventory ---
    # The job launch triggers an inventory sync which creates the constructed
    # host with instance_id pointing to one of the source hosts.
    scm_url = f'file://{live_tmp_folder}/facts'
    run_job_from_playbook(
        'test_ci_facts',
        'gather.yml',
        scm_url=scm_url,
        jt_params={'use_fact_cache': True, 'inventory': constructed_inv.id},
    )
    # --- Determine which source host the constructed host points to ---
    # NOTE: instance_id is stored as a string on the constructed host, hence int().
    constructed_host = constructed_inv.hosts.get(name=shared_name)
    target_id = int(constructed_host.instance_id)
    other_id = (set(source_hosts_by_id.keys()) - {target_id}).pop()
    target_host = source_hosts_by_id[target_id]
    other_host = source_hosts_by_id[other_id]
    target_host.refresh_from_db()
    other_host.refresh_from_db()
    constructed_host.refresh_from_db()
    # Single combined comparison so a failure message shows all three states at once.
    actual = [target_host.ansible_facts.get('foo'), other_host.ansible_facts, constructed_host.ansible_facts]
    expected = ['bar', original_facts_by_id[other_id], {}]
    assert actual == expected, (
        f'Constructed inventory fact cache wrote to wrong host(s). '
        f'target source host (id={target_id}) foo={actual[0]!r}, '
        f'other source host (id={other_id}) facts={actual[1]!r}, '
        f'constructed host facts={actual[2]!r}; expected {expected!r}'
    )

View File

@@ -1,347 +0,0 @@
"""
Integration tests for external query file functionality (AAP-58470).
Tests verify the end-to-end external query file workflow for indirect node
counting using real AWX job execution. A fixture-created vendor collection
at /var/lib/awx/vendor_collections/ provides external query files, simulating
what the build-time (AAP-58426) and deployment (AAP-58557) integrations will
provide once available.
Test data:
- Collection 'demo.external' at various versions (no embedded query)
- External query files in mock redhat.indirect_accounting collection
"""
import os
import shutil
import time
import yaml
import pytest
from flags.state import enable_flag, disable_flag, flag_enabled
from awx.main.tests.live.tests.conftest import wait_for_events, unified_job_stdout
from awx.main.tasks.host_indirect import save_indirect_host_entries
from awx.main.models.indirect_managed_node_audit import IndirectManagedNodeAudit
from awx.main.models.event_query import EventQuery
from awx.main.models import Job
# --- Constants ---
# jq projection applied to demo.external.example events by the external query
# file: extracts the host name, a canonical host_name fact, and device_type.
EXTERNAL_QUERY_JQ = '{name: .name, canonical_facts: {host_name: .direct_host_name}, facts: {device_type: .device_type}}'
EXTERNAL_QUERY_CONTENT = yaml.dump(
    {'demo.external.example': {'query': EXTERNAL_QUERY_JQ}},
    default_flow_style=False,
)
# For precedence test: different jq (no device_type in facts) so we can detect which query was used
EXTERNAL_QUERY_FOR_DEMO_QUERY_JQ = '{name: .name, canonical_facts: {host_name: .direct_host_name}, facts: {}}'
EXTERNAL_QUERY_FOR_DEMO_QUERY_CONTENT = yaml.dump(
    {'demo.query.example': {'query': EXTERNAL_QUERY_FOR_DEMO_QUERY_JQ}},
    default_flow_style=False,
)
# Root under which the mock vendor collection is created (see vendor_collections_dir)
VENDOR_COLLECTIONS_BASE = '/var/lib/awx/vendor_collections'
# --- Fixtures ---
@pytest.fixture
def enable_indirect_host_counting():
    """Turn on FEATURE_INDIRECT_NODE_COUNTING_ENABLED for the duration of a test.

    A FlagState DB record is created only when the flag is not already on
    (e.g. via development_defaults.py), which avoids UniqueViolation errors
    and prevents leaking flag state into other tests.
    """
    flag_name = "FEATURE_INDIRECT_NODE_COUNTING_ENABLED"
    needs_toggle = not flag_enabled(flag_name)
    if needs_toggle:
        enable_flag(flag_name)
    yield
    if needs_toggle:
        disable_flag(flag_name)
@pytest.fixture
def vendor_collections_dir():
    """Create a mock redhat.indirect_accounting collection under /var/lib/awx/vendor_collections/.

    Lays down the collection skeleton plus the external query files used by
    the tests in this module:
    - demo.external.1.0.0.yml (exact match for v1.0.0)
    - demo.external.1.1.0.yml (fallback target for v1.5.0)
    - demo.query.0.0.1.yml (for precedence test with embedded-query collection)
    """
    collection_root = os.path.join(VENDOR_COLLECTIONS_BASE, 'ansible_collections', 'redhat', 'indirect_accounting')
    external_queries = os.path.join(collection_root, 'extensions', 'audit', 'external_queries')
    meta_dir = os.path.join(collection_root, 'meta')
    for directory in (external_queries, meta_dir):
        os.makedirs(directory, exist_ok=True)
    # galaxy.yml so the directory parses as a valid collection
    galaxy_manifest = {
        'namespace': 'redhat',
        'name': 'indirect_accounting',
        'version': '1.0.0',
        'description': 'Test fixture for external query integration tests',
        'authors': ['AWX Tests'],
        'dependencies': {},
    }
    with open(os.path.join(collection_root, 'galaxy.yml'), 'w') as f:
        yaml.dump(galaxy_manifest, f)
    # meta/runtime.yml
    with open(os.path.join(meta_dir, 'runtime.yml'), 'w') as f:
        yaml.dump({'requires_ansible': '>=2.15.0'}, f)
    # External query files: two demo.external versions plus the demo.query one
    query_files = {
        'demo.external.1.0.0.yml': EXTERNAL_QUERY_CONTENT,
        'demo.external.1.1.0.yml': EXTERNAL_QUERY_CONTENT,
        'demo.query.0.0.1.yml': EXTERNAL_QUERY_FOR_DEMO_QUERY_CONTENT,
    }
    for filename, content in query_files.items():
        with open(os.path.join(external_queries, filename), 'w') as f:
            f.write(content)
    yield collection_root
    # Teardown: remove only the collection this fixture created, never the vendor root.
    shutil.rmtree(collection_root, ignore_errors=True)
@pytest.fixture(autouse=True)
def cleanup_test_data():
    """Remove EventQuery and IndirectManagedNodeAudit rows created by each test."""
    yield
    for fqcn in ('demo.external', 'demo.query'):
        EventQuery.objects.filter(fqcn=fqcn).delete()
    IndirectManagedNodeAudit.objects.filter(job__name__icontains='external_query').delete()
# --- Helpers ---
def run_external_query_job(run_job_from_playbook, live_tmp_folder, test_name, project_dir, jt_params=None):
    """Launch run_task.yml from project_dir and return the Job once its events have landed."""
    run_job_from_playbook(test_name, 'run_task.yml', scm_url=f'file://{live_tmp_folder}/{project_dir}', jt_params=jt_params)
    # The playbook runner names the job after test_name; grab the newest match.
    job = Job.objects.filter(name__icontains=test_name).order_by('-created').first()
    assert job is not None, f'Job not found for test {test_name}'
    wait_for_events(job)
    return job
def wait_for_indirect_processing(job, expect_records=True, timeout=5):
    """Wait for indirect host processing of *job* to complete.

    Follows the same pattern as test_indirect_host_counting.py.

    Triggers save_indirect_host_entries if the job's event queries have not
    been processed yet, then either polls for IndirectManagedNodeAudit
    records to appear (expect_records=True) or waits out *timeout* seconds
    to confirm none appear (expect_records=False).

    Fix: the ``timeout`` argument was previously ignored on the
    expect_records=True path (a hard-coded 20 x 0.25s loop); it now bounds
    both paths. The default of 5 preserves the old 5-second behavior.

    Raises RuntimeError if records were expected but did not appear in time.
    """
    # Ensure indirect host processing runs (wait_for_events already called by caller)
    job.refresh_from_db()
    if job.event_queries_processed is False:
        save_indirect_host_entries.delay(job.id, wait_for_events=False)
    if expect_records:
        # Poll until audit records exist or the deadline passes.
        deadline = time.monotonic() + timeout
        while not IndirectManagedNodeAudit.objects.filter(job=job).exists():
            if time.monotonic() >= deadline:
                raise RuntimeError(f'No IndirectManagedNodeAudit records populated for job_id={job.id}')
            time.sleep(0.25)
    else:
        # For negative tests, wait a reasonable time to confirm no records appear
        time.sleep(timeout)
    job.refresh_from_db()
# --- AC8.1: External query populates IndirectManagedNodeAudit correctly ---
def test_external_query_populates_audit_table(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
    """AC8.1: Job using demo.external.example with external query file populates
    IndirectManagedNodeAudit table correctly.

    Uses demo.external v1.0.0 with exact-match external query file demo.external.1.0.0.yml.
    """
    job = run_external_query_job(
        run_job_from_playbook,
        live_tmp_folder,
        'external_query_ac8_1',
        'test_host_query_external_v1_0_0',
    )
    wait_for_indirect_processing(job, expect_records=True)
    # demo.external must have been captured in installed_collections with a host_query
    assert 'demo.external' in job.installed_collections
    assert 'host_query' in job.installed_collections['demo.external']
    # Exactly one audit record, carrying the external query's jq projection
    audits = IndirectManagedNodeAudit.objects.filter(job=job)
    assert audits.count() == 1
    host_audit = audits.first()
    assert host_audit.canonical_facts == {'host_name': 'foo_host_default'}
    assert host_audit.facts == {'device_type': 'Fake Host'}
    assert host_audit.name == 'vm-foo'
    assert host_audit.organization == job.organization
    assert 'demo.external.example' in host_audit.events
# --- AC8.2: Precedence - embedded query takes precedence over external ---
def test_embedded_query_takes_precedence(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
    """AC8.2: When collection has both embedded and external query files,
    the embedded query takes precedence.

    Uses demo.query v0.0.1 which HAS an embedded query (extensions/audit/event_query.yml).
    An external query (demo.query.0.0.1.yml) also exists but uses a different jq expression
    (no device_type in facts). By checking the audit record's facts, we verify which query was used.
    """
    # Run against the demo.query collection, which carries its own embedded query.
    job = run_external_query_job(
        run_job_from_playbook,
        live_tmp_folder,
        'external_query_ac8_2',
        'test_host_query',
    )
    wait_for_indirect_processing(job, expect_records=True)
    # The embedded query includes device_type; the external one would not.
    host_audit = IndirectManagedNodeAudit.objects.filter(job=job).first()
    assert host_audit.facts == {'device_type': 'Fake Host'}, (
        'Expected embedded query output (with device_type). If facts is {}, the external query was incorrectly used instead.'
    )
# --- AC8.3: Version fallback to compatible version ---
def test_fallback_to_compatible_version(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
    """AC8.3: Job using collection version with no exact query file falls back
    correctly to compatible version.

    Uses demo.external v1.5.0. No demo.external.1.5.0.yml exists, but
    demo.external.1.1.0.yml is available (same major version, highest <= 1.5.0).
    The fallback should find and use the 1.1.0 query.
    """
    job = run_external_query_job(
        run_job_from_playbook,
        live_tmp_folder,
        'external_query_ac8_3',
        'test_host_query_external_v1_5_0',
    )
    wait_for_indirect_processing(job, expect_records=True)
    # Verify installed_collections captured demo.external at v1.5.0
    assert 'demo.external' in job.installed_collections
    assert job.installed_collections['demo.external']['version'] == '1.5.0'
    # Verify IndirectManagedNodeAudit records were created via fallback
    assert IndirectManagedNodeAudit.objects.filter(job=job).count() == 1
    host_audit = IndirectManagedNodeAudit.objects.filter(job=job).first()
    # Same projection as the exact-match case — the 1.1.0 query file is written
    # with the same jq expression as 1.0.0 (see vendor_collections_dir fixture).
    assert host_audit.canonical_facts == {'host_name': 'foo_host_default'}
    assert host_audit.facts == {'device_type': 'Fake Host'}
    assert host_audit.name == 'vm-foo'
# --- AC8.4: Fallback queries don't overcount ---
def test_fallback_does_not_overcount(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
    """AC8.4: Fallback queries don't count MORE nodes than exact-version queries.

    Runs two jobs:
    1. Exact match scenario (demo.external v1.0.0 -> demo.external.1.0.0.yml)
    2. Fallback scenario (demo.external v1.5.0 -> falls back to demo.external.1.1.0.yml)
    Verifies that fallback record count <= exact record count.
    """
    # Run exact-match job
    exact_job = run_external_query_job(
        run_job_from_playbook,
        live_tmp_folder,
        'external_query_ac8_4_exact',
        'test_host_query_external_v1_0_0',
    )
    wait_for_indirect_processing(exact_job, expect_records=True)
    exact_count = IndirectManagedNodeAudit.objects.filter(job=exact_job).count()
    # Run fallback job
    fallback_job = run_external_query_job(
        run_job_from_playbook,
        live_tmp_folder,
        'external_query_ac8_4_fallback',
        'test_host_query_external_v1_5_0',
    )
    wait_for_indirect_processing(fallback_job, expect_records=True)
    fallback_count = IndirectManagedNodeAudit.objects.filter(job=fallback_job).count()
    # Critical safety check: fallback must never count MORE than exact
    assert fallback_count <= exact_count, (
        f'Overcounting detected! Fallback produced {fallback_count} records ' f'but exact match produced only {exact_count} records.'
    )
    # Both use the same jq expression and same module, so counts should be equal
    # (the 1.0.0 and 1.1.0 query files are identical — see vendor_collections_dir).
    assert exact_count == fallback_count
# --- AC8.5: Warning logs contain correct version information ---
def test_fallback_log_contains_version_info(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
    """AC8.5: Warning logs contain correct version information when fallback is used.

    Runs a job with verbosity=1 so callback plugin verbose output is captured.
    Verifies the log contains the installed version (1.5.0), fallback version (1.1.0),
    and collection FQCN (demo.external).
    """
    job = run_external_query_job(
        run_job_from_playbook,
        live_tmp_folder,
        'external_query_ac8_5',
        'test_host_query_external_v1_5_0',
        jt_params={'verbosity': 1},
    )
    wait_for_indirect_processing(job, expect_records=True)
    # Get job stdout to check for fallback log message
    stdout = unified_job_stdout(job)
    # The callback plugin emits: "Using external query {version_used} for {fqcn} v{ver}."
    # Assert on the three pieces individually rather than the full sentence so
    # minor wording changes in the plugin don't break the test.
    assert '1.1.0' in stdout, f'Fallback version 1.1.0 not found in job stdout. stdout:\n{stdout}'
    assert 'demo.external' in stdout, f'Collection FQCN demo.external not found in job stdout. stdout:\n{stdout}'
    assert '1.5.0' in stdout, f'Installed version 1.5.0 not found in job stdout. stdout:\n{stdout}'
# --- AC8.6: No counting when no compatible fallback exists ---
def test_no_counting_without_compatible_fallback(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
    """AC8.6: No counting occurs when no compatible fallback exists.

    Uses demo.external v3.0.0 with only v1.x external query files available.
    Since major versions differ (3 vs 1), no fallback should occur and no
    IndirectManagedNodeAudit records should be created.
    """
    job = run_external_query_job(
        run_job_from_playbook,
        live_tmp_folder,
        'external_query_ac8_6',
        'test_host_query_external_v3_0_0',
    )
    wait_for_indirect_processing(job, expect_records=False)
    # The major-version mismatch must leave the audit table empty for this job.
    audit_count = IndirectManagedNodeAudit.objects.filter(job=job).count()
    assert audit_count == 0, (
        'IndirectManagedNodeAudit records were created despite no compatible ' 'fallback existing for demo.external v3.0.0 (only v1.x queries available).'
    )

View File

@@ -1,49 +0,0 @@
import pytest
from collections import OrderedDict
from unittest import mock
from rest_framework.exceptions import ValidationError
from awx.api.fields import DeprecatedCredentialField
class TestDeprecatedCredentialField:
    """DeprecatedCredentialField must handle unexpected input types gracefully.

    Bad input should surface as a DRF ValidationError (HTTP 400) rather than
    an unhandled TypeError (HTTP 500).
    """

    def test_dict_value_raises_validation_error(self):
        """Passing a dict instead of an integer should return a 400 validation error, not a 500 TypeError."""
        with pytest.raises(ValidationError):
            DeprecatedCredentialField().to_internal_value({"username": "admin", "password": "secret"})

    def test_ordered_dict_value_raises_validation_error(self):
        """Passing an OrderedDict should return a 400 validation error, not a 500 TypeError."""
        with pytest.raises(ValidationError):
            DeprecatedCredentialField().to_internal_value(OrderedDict([("username", "admin")]))

    def test_list_value_raises_validation_error(self):
        """Passing a list should return a 400 validation error, not a 500 TypeError."""
        with pytest.raises(ValidationError):
            DeprecatedCredentialField().to_internal_value([1, 2, 3])

    def test_string_value_raises_validation_error(self):
        """Passing a non-numeric string should return a 400 validation error."""
        with pytest.raises(ValidationError):
            DeprecatedCredentialField().to_internal_value("not_a_number")

    @mock.patch('awx.api.fields.Credential.objects')
    def test_valid_integer_value_works(self, mock_cred_objects):
        """Passing a valid integer PK should work when the credential exists."""
        mock_cred_objects.get.return_value = mock.MagicMock()
        assert DeprecatedCredentialField().to_internal_value(42) == 42

    @mock.patch('awx.api.fields.Credential.objects')
    def test_valid_string_integer_value_works(self, mock_cred_objects):
        """Passing a numeric string PK should work when the credential exists."""
        mock_cred_objects.get.return_value = mock.MagicMock()
        assert DeprecatedCredentialField().to_internal_value("42") == 42

View File

@@ -1,4 +1,3 @@
import copy
import warnings import warnings
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
@@ -9,7 +8,6 @@ from awx.api.schema import (
AuthenticatedSpectacularAPIView, AuthenticatedSpectacularAPIView,
AuthenticatedSpectacularSwaggerView, AuthenticatedSpectacularSwaggerView,
AuthenticatedSpectacularRedocView, AuthenticatedSpectacularRedocView,
filter_credential_type_schema,
) )
@@ -273,152 +271,3 @@ class TestAuthenticatedSchemaViews:
def test_authenticated_spectacular_redoc_view_requires_authentication(self): def test_authenticated_spectacular_redoc_view_requires_authentication(self):
"""Test that AuthenticatedSpectacularRedocView requires authentication.""" """Test that AuthenticatedSpectacularRedocView requires authentication."""
assert IsAuthenticated in AuthenticatedSpectacularRedocView.permission_classes assert IsAuthenticated in AuthenticatedSpectacularRedocView.permission_classes
class TestFilterCredentialTypeSchema:
    """Unit tests for filter_credential_type_schema postprocessing hook.

    The hook is called with (result, generator, request, public); these tests
    pass None for the last three since the hook only inspects ``result``.
    """

    def test_filters_both_schemas_correctly(self):
        """Test that both CredentialTypeRequest and PatchedCredentialTypeRequest schemas are filtered."""
        result = {
            'components': {
                'schemas': {
                    'CredentialTypeRequest': {
                        'properties': {
                            'kind': {
                                'enum': [
                                    'ssh',
                                    'vault',
                                    'net',
                                    'scm',
                                    'cloud',
                                    'registry',
                                    'token',
                                    'insights',
                                    'external',
                                    'kubernetes',
                                    'galaxy',
                                    'cryptography',
                                    None,
                                ],
                                'type': 'string',
                            }
                        }
                    },
                    'PatchedCredentialTypeRequest': {
                        'properties': {
                            'kind': {
                                'enum': [
                                    'ssh',
                                    'vault',
                                    'net',
                                    'scm',
                                    'cloud',
                                    'registry',
                                    'token',
                                    'insights',
                                    'external',
                                    'kubernetes',
                                    'galaxy',
                                    'cryptography',
                                    None,
                                ],
                                'type': 'string',
                            }
                        }
                    },
                }
            }
        }
        returned = filter_credential_type_schema(result, None, None, None)
        # POST/PUT schema: no None (required field)
        assert result['components']['schemas']['CredentialTypeRequest']['properties']['kind']['enum'] == ['cloud', 'net']
        assert result['components']['schemas']['CredentialTypeRequest']['properties']['kind']['description'] == "* `cloud` - Cloud\\n* `net` - Network"
        # PATCH schema: includes None (optional field)
        assert result['components']['schemas']['PatchedCredentialTypeRequest']['properties']['kind']['enum'] == ['cloud', 'net', None]
        assert result['components']['schemas']['PatchedCredentialTypeRequest']['properties']['kind']['description'] == "* `cloud` - Cloud\\n* `net` - Network"
        # Other properties should be preserved
        assert result['components']['schemas']['CredentialTypeRequest']['properties']['kind']['type'] == 'string'
        # Function should return the result
        assert returned is result

    def test_handles_empty_result(self):
        """Test graceful handling when result dict is empty."""
        result = {}
        original = copy.deepcopy(result)
        returned = filter_credential_type_schema(result, None, None, None)
        assert result == original
        assert returned is result

    def test_handles_missing_enum(self):
        """Test that schemas without enum key are not modified."""
        result = {'components': {'schemas': {'CredentialTypeRequest': {'properties': {'kind': {'type': 'string', 'description': 'Some description'}}}}}}
        original = copy.deepcopy(result)
        filter_credential_type_schema(result, None, None, None)
        assert result == original

    def test_filters_only_target_schemas(self):
        """Test that only CredentialTypeRequest schemas are modified, not others."""
        result = {
            'components': {
                'schemas': {
                    'CredentialTypeRequest': {'properties': {'kind': {'enum': ['ssh', 'cloud', 'net', None]}}},
                    'OtherSchema': {'properties': {'kind': {'enum': ['option1', 'option2']}}},
                }
            }
        }
        # Snapshot the unrelated schema so we can prove it survived untouched.
        other_schema_before = copy.deepcopy(result['components']['schemas']['OtherSchema'])
        filter_credential_type_schema(result, None, None, None)
        # CredentialTypeRequest should be filtered (no None for required field)
        assert result['components']['schemas']['CredentialTypeRequest']['properties']['kind']['enum'] == ['cloud', 'net']
        # OtherSchema should be unchanged
        assert result['components']['schemas']['OtherSchema'] == other_schema_before

    def test_handles_only_one_schema_present(self):
        """Test that function works when only one target schema is present."""
        result = {'components': {'schemas': {'CredentialTypeRequest': {'properties': {'kind': {'enum': ['ssh', 'cloud', 'net', None]}}}}}}
        filter_credential_type_schema(result, None, None, None)
        assert result['components']['schemas']['CredentialTypeRequest']['properties']['kind']['enum'] == ['cloud', 'net']

    def test_handles_missing_properties(self):
        """Test graceful handling when schema has no properties key."""
        result = {'components': {'schemas': {'CredentialTypeRequest': {}}}}
        original = copy.deepcopy(result)
        filter_credential_type_schema(result, None, None, None)
        assert result == original

    def test_differentiates_required_vs_optional_fields(self):
        """Test that CredentialTypeRequest excludes None but PatchedCredentialTypeRequest includes it."""
        result = {
            'components': {
                'schemas': {
                    'CredentialTypeRequest': {'properties': {'kind': {'enum': ['ssh', 'vault', 'net', 'scm', 'cloud', 'registry', None]}}},
                    'PatchedCredentialTypeRequest': {'properties': {'kind': {'enum': ['ssh', 'vault', 'net', 'scm', 'cloud', 'registry', None]}}},
                }
            }
        }
        filter_credential_type_schema(result, None, None, None)
        # POST/PUT schema: no None (required field)
        assert result['components']['schemas']['CredentialTypeRequest']['properties']['kind']['enum'] == ['cloud', 'net']
        # PATCH schema: includes None (optional field)
        assert result['components']['schemas']['PatchedCredentialTypeRequest']['properties']['kind']['enum'] == ['cloud', 'net', None]

View File

@@ -1,92 +0,0 @@
import io
import pytest
from django.core.management.base import CommandError
from awx.main.management.commands import dispatcherctl
@pytest.fixture(autouse=True)
def clear_dispatcher_env(monkeypatch, mocker):
    # Runs for every test in this module: drop any ambient dispatcherd config
    # file reference, silence logging configuration, and fake a postgresql DB
    # connection (the command rejects sqlite — see test_dispatcherctl_rejects_sqlite_db).
    monkeypatch.delenv('DISPATCHERD_CONFIG_FILE', raising=False)
    mocker.patch.object(dispatcherctl.logging, 'basicConfig')
    mocker.patch.object(dispatcherctl, 'connection', mocker.Mock(vendor='postgresql'))
def test_dispatcherctl_runs_control_with_generated_config(mocker):
    """Happy path: handle() builds config, sets up dispatcherd, issues the control command and prints the reply."""
    cmd = dispatcherctl.Command()
    cmd.stdout = io.StringIO()
    command_data = {'foo': 'bar'}
    mocker.patch.object(dispatcherctl, '_build_command_data_from_args', return_value=command_data)
    setup_mock = mocker.patch.object(dispatcherctl, 'dispatcher_setup')
    generated_config = {'setting': 'value'}
    mocker.patch.object(dispatcherctl, 'get_dispatcherd_config', return_value=generated_config)
    control_mock = mocker.Mock()
    control_mock.control_with_reply.return_value = [{'status': 'ok'}]
    mocker.patch.object(dispatcherctl, 'get_control_from_settings', return_value=control_mock)
    mocker.patch.object(dispatcherctl.yaml, 'dump', return_value='payload\n')
    cmd.handle(
        command='running',
        config=dispatcherctl.DEFAULT_CONFIG_FILE,
        expected_replies=1,
        log_level='INFO',
    )
    # Setup gets the generated config, the control call carries our data,
    # and the yaml-dumped reply lands on stdout.
    setup_mock.assert_called_once_with(generated_config)
    control_mock.control_with_reply.assert_called_once_with('running', data=command_data, expected_replies=1)
    assert cmd.stdout.getvalue() == 'payload\n'
def test_dispatcherctl_rejects_custom_config_path():
    """A config path other than DEFAULT_CONFIG_FILE is refused with CommandError."""
    cmd = dispatcherctl.Command()
    cmd.stdout = io.StringIO()
    with pytest.raises(CommandError):
        cmd.handle(
            command='running',
            config='/tmp/dispatcher.yml',
            expected_replies=1,
            log_level='INFO',
        )
def test_dispatcherctl_rejects_sqlite_db(mocker):
    """A sqlite DB connection is refused — dispatcherd needs postgresql."""
    cmd = dispatcherctl.Command()
    cmd.stdout = io.StringIO()
    # Override the autouse fixture's postgresql connection with sqlite.
    mocker.patch.object(dispatcherctl, 'connection', mocker.Mock(vendor='sqlite'))
    with pytest.raises(CommandError, match='sqlite3'):
        cmd.handle(
            command='running',
            config=dispatcherctl.DEFAULT_CONFIG_FILE,
            expected_replies=1,
            log_level='INFO',
        )
def test_dispatcherctl_raises_when_replies_missing(mocker):
    """Fewer replies than expected_replies raises CommandError after the control call."""
    cmd = dispatcherctl.Command()
    cmd.stdout = io.StringIO()
    mocker.patch.object(dispatcherctl, '_build_command_data_from_args', return_value={})
    mocker.patch.object(dispatcherctl, 'dispatcher_setup')
    mocker.patch.object(dispatcherctl, 'get_dispatcherd_config', return_value={})
    control_mock = mocker.Mock()
    # Only one reply comes back even though two are expected below.
    control_mock.control_with_reply.return_value = [{'status': 'ok'}]
    mocker.patch.object(dispatcherctl, 'get_control_from_settings', return_value=control_mock)
    mocker.patch.object(dispatcherctl.yaml, 'dump', return_value='- status: ok\n')
    with pytest.raises(CommandError):
        cmd.handle(
            command='running',
            config=dispatcherctl.DEFAULT_CONFIG_FILE,
            expected_replies=2,
            log_level='INFO',
        )
    control_mock.control_with_reply.assert_called_once_with('running', data={}, expected_replies=2)

View File

@@ -47,34 +47,3 @@ def test__get_credential_type_class_invalid_params():
assert type(e.value) is ValueError assert type(e.value) is ValueError
assert str(e.value) == 'Expected only apps or app_config to be defined, not both' assert str(e.value) == 'Expected only apps or app_config to be defined, not both'
def test_credential_context_property():
    """The context property lazily creates an empty dict and hands back the same object on every access."""
    cred_type = CredentialType(name='Test Cred', kind='vault')
    credential = Credential(id=1, name='Test Credential', credential_type=cred_type, inputs={})
    # First access initializes an empty dict
    first = credential.context
    assert first == {}
    # Mutations through the returned dict are visible on later accesses
    first['test_key'] = 'test_value'
    assert credential.context == {'test_key': 'test_value'}
    assert credential.context is first  # same object, not a fresh dict
def test_credential_context_property_independent_instances():
    """Each Credential instance owns its own context dict — no sharing across instances."""
    cred_type = CredentialType(name='Test Cred', kind='vault')
    first = Credential(id=1, name='Cred 1', credential_type=cred_type, inputs={})
    second = Credential(id=2, name='Cred 2', credential_type=cred_type, inputs={})
    first.context['key1'] = 'value1'
    second.context['key2'] = 'value2'
    assert first.context == {'key1': 'value1'}
    assert second.context == {'key2': 'value2'}
    assert first.context is not second.context

View File

@@ -2,7 +2,6 @@
import json import json
import os import os
import pytest import pytest
from unittest import mock
from awx.main.models import ( from awx.main.models import (
Inventory, Inventory,
@@ -100,55 +99,52 @@ def test_start_job_fact_cache_within_timeout(hosts, tmpdir):
def test_finish_job_fact_cache_clear(hosts, mocker, ref_time, tmpdir): def test_finish_job_fact_cache_clear(hosts, mocker, ref_time, tmpdir):
artifacts_dir = str(tmpdir.mkdir("artifacts")) fact_cache = os.path.join(tmpdir, 'facts')
inventory_id = 5 start_fact_cache(hosts, fact_cache, timeout=0)
start_fact_cache(hosts, artifacts_dir=artifacts_dir, timeout=0, inventory_id=inventory_id)
mocker.patch('awx.main.tasks.facts.bulk_update_sorted_by_id')
# Remove the fact file for hosts[1] to simulate ansible's clear_facts
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
os.remove(os.path.join(fact_cache_dir, hosts[1].name))
hosts_qs = mock.MagicMock()
hosts_qs.filter.return_value.order_by.return_value.iterator.return_value = iter(hosts)
finish_fact_cache(hosts_qs, artifacts_dir=artifacts_dir, inventory_id=inventory_id)
# hosts[1] should have had its facts cleared (file was missing, job_created=None)
assert hosts[1].ansible_facts == {}
assert hosts[1].ansible_facts_modified > ref_time
# Other hosts should be unmodified (fact files exist but weren't changed by ansible)
for host in (hosts[0], hosts[2], hosts[3]):
assert host.ansible_facts == {"a": 1, "b": 2}
assert host.ansible_facts_modified == ref_time
def test_finish_job_fact_cache_with_bad_data(hosts, mocker, tmpdir):
artifacts_dir = str(tmpdir.mkdir("artifacts"))
inventory_id = 5
start_fact_cache(hosts, artifacts_dir=artifacts_dir, timeout=0, inventory_id=inventory_id)
bulk_update = mocker.patch('awx.main.tasks.facts.bulk_update_sorted_by_id') bulk_update = mocker.patch('awx.main.tasks.facts.bulk_update_sorted_by_id')
# Overwrite fact files with invalid JSON and set future mtime # Mock the os.path.exists behavior for host deletion
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache') # Let's assume the fact file for hosts[1] is missing.
mocker.patch('os.path.exists', side_effect=lambda path: hosts[1].name not in path)
# Simulate one host's fact file getting deleted manually
host_to_delete_filepath = os.path.join(fact_cache, hosts[1].name)
# Simulate the file being removed by checking existence first, to avoid FileNotFoundError
if os.path.exists(host_to_delete_filepath):
os.remove(host_to_delete_filepath)
finish_fact_cache(fact_cache)
# Simulate side effects that would normally be applied during bulk update
hosts[1].ansible_facts = {}
hosts[1].ansible_facts_modified = now()
# Verify facts are preserved for hosts with valid cache files
for host in (hosts[0], hosts[2], hosts[3]):
assert host.ansible_facts == {"a": 1, "b": 2}
assert host.ansible_facts_modified == ref_time
assert hosts[1].ansible_facts_modified > ref_time
# Current implementation skips the call entirely if hosts_to_update == []
bulk_update.assert_not_called()
def test_finish_job_fact_cache_with_bad_data(hosts, mocker, tmpdir):
fact_cache = os.path.join(tmpdir, 'facts')
start_fact_cache(hosts, fact_cache, timeout=0)
bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update')
for h in hosts: for h in hosts:
filepath = os.path.join(fact_cache_dir, h.name) filepath = os.path.join(fact_cache, h.name)
with open(filepath, 'w') as f: with open(filepath, 'w') as f:
f.write('not valid json!') f.write('not valid json!')
f.flush() f.flush()
new_modification_time = time.time() + 3600 new_modification_time = time.time() + 3600
os.utime(filepath, (new_modification_time, new_modification_time)) os.utime(filepath, (new_modification_time, new_modification_time))
hosts_qs = mock.MagicMock() finish_fact_cache(fact_cache)
hosts_qs.filter.return_value.order_by.return_value.iterator.return_value = iter(hosts)
finish_fact_cache(hosts_qs, artifacts_dir=artifacts_dir, inventory_id=inventory_id) bulk_update.assert_not_called()
# Invalid JSON should be skipped — no hosts updated
updated_hosts = bulk_update.call_args[0][1]
assert updated_hosts == []

View File

@@ -176,22 +176,22 @@ def test_display_survey_spec_encrypts_default(survey_spec_factory):
@pytest.mark.survey @pytest.mark.survey
@pytest.mark.parametrize( @pytest.mark.parametrize(
"question_type,default,min,max,expect_valid,expect_use,expect_value", "question_type,default,min,max,expect_use,expect_value",
[ [
("text", "", 0, 0, True, False, 'N/A'), # valid but empty default not sent for optional question ("text", "", 0, 0, True, ''), # default used
("text", "", 1, 0, False, False, 'N/A'), # value less than min length ("text", "", 1, 0, False, 'N/A'), # value less than min length
("password", "", 1, 0, False, False, 'N/A'), # passwords behave the same as text ("password", "", 1, 0, False, 'N/A'), # passwords behave the same as text
("multiplechoice", "", 0, 0, False, False, 'N/A'), # historical bug ("multiplechoice", "", 0, 0, False, 'N/A'), # historical bug
("multiplechoice", "zeb", 0, 0, False, False, 'N/A'), # zeb not in choices ("multiplechoice", "zeb", 0, 0, False, 'N/A'), # zeb not in choices
("multiplechoice", "coffee", 0, 0, True, True, 'coffee'), ("multiplechoice", "coffee", 0, 0, True, 'coffee'),
("multiselect", None, 0, 0, False, False, 'N/A'), # NOTE: Behavior is arguable, value of [] may be prefered ("multiselect", None, 0, 0, False, 'N/A'), # NOTE: Behavior is arguable, value of [] may be prefered
("multiselect", "", 0, 0, False, False, 'N/A'), ("multiselect", "", 0, 0, False, 'N/A'),
("multiselect", ["zeb"], 0, 0, False, False, 'N/A'), ("multiselect", ["zeb"], 0, 0, False, 'N/A'),
("multiselect", ["milk"], 0, 0, True, True, ["milk"]), ("multiselect", ["milk"], 0, 0, True, ["milk"]),
("multiselect", ["orange\nmilk"], 0, 0, False, False, 'N/A'), # historical bug ("multiselect", ["orange\nmilk"], 0, 0, False, 'N/A'), # historical bug
], ],
) )
def test_optional_survey_question_defaults(survey_spec_factory, question_type, default, min, max, expect_valid, expect_use, expect_value): def test_optional_survey_question_defaults(survey_spec_factory, question_type, default, min, max, expect_use, expect_value):
spec = survey_spec_factory( spec = survey_spec_factory(
[ [
{ {
@@ -208,7 +208,7 @@ def test_optional_survey_question_defaults(survey_spec_factory, question_type, d
jt = JobTemplate(name="test-jt", survey_spec=spec, survey_enabled=True) jt = JobTemplate(name="test-jt", survey_spec=spec, survey_enabled=True)
defaulted_extra_vars = jt._update_unified_job_kwargs({}, {}) defaulted_extra_vars = jt._update_unified_job_kwargs({}, {})
element = spec['spec'][0] element = spec['spec'][0]
if expect_valid: if expect_use:
assert jt._survey_element_validation(element, {element['variable']: element['default']}) == [] assert jt._survey_element_validation(element, {element['variable']: element['default']}) == []
else: else:
assert jt._survey_element_validation(element, {element['variable']: element['default']}) assert jt._survey_element_validation(element, {element['variable']: element['default']})
@@ -218,28 +218,6 @@ def test_optional_survey_question_defaults(survey_spec_factory, question_type, d
assert 'c' not in defaulted_extra_vars['extra_vars'] assert 'c' not in defaulted_extra_vars['extra_vars']
@pytest.mark.survey
def test_optional_survey_empty_default_with_runtime_extra_var(survey_spec_factory):
"""When a user explicitly provides an empty string at runtime for an optional
survey question, the variable should still be included in extra_vars."""
spec = survey_spec_factory(
[
{
"required": False,
"default": "",
"choices": "",
"variable": "c",
"min": 0,
"max": 0,
"type": "text",
},
]
)
jt = JobTemplate(name="test-jt", survey_spec=spec, survey_enabled=True)
defaulted_extra_vars = jt._update_unified_job_kwargs({}, {'extra_vars': json.dumps({'c': ''})})
assert json.loads(defaulted_extra_vars['extra_vars'])['c'] == ''
@pytest.mark.survey @pytest.mark.survey
@pytest.mark.parametrize( @pytest.mark.parametrize(
"question_type,default,maxlen,kwargs,expected", "question_type,default,maxlen,kwargs,expected",

View File

@@ -1,3 +1,4 @@
import pytest
from unittest import mock from unittest import mock
from awx.main.models import UnifiedJob, UnifiedJobTemplate, WorkflowJob, WorkflowJobNode, WorkflowApprovalTemplate, Job, User, Project, JobTemplate, Inventory from awx.main.models import UnifiedJob, UnifiedJobTemplate, WorkflowJob, WorkflowJobNode, WorkflowApprovalTemplate, Job, User, Project, JobTemplate, Inventory
@@ -21,6 +22,52 @@ def test_unified_job_workflow_attributes():
assert job.workflow_job_id == 1 assert job.workflow_job_id == 1
def mock_on_commit(f):
f()
@pytest.fixture
def unified_job(mocker):
mocker.patch.object(UnifiedJob, 'can_cancel', return_value=True)
j = UnifiedJob()
j.status = 'pending'
j.cancel_flag = None
j.save = mocker.MagicMock()
j.websocket_emit_status = mocker.MagicMock()
j.fallback_cancel = mocker.MagicMock()
return j
def test_cancel(unified_job):
with mock.patch('awx.main.models.unified_jobs.connection.on_commit', wraps=mock_on_commit):
unified_job.cancel()
assert unified_job.cancel_flag is True
assert unified_job.status == 'canceled'
assert unified_job.job_explanation == ''
# Note: the websocket emit status check is just reflecting the state of the current code.
# Some more thought may want to go into only emitting canceled if/when the job record
# status is changed to canceled. Unlike, currently, where it's emitted unconditionally.
unified_job.websocket_emit_status.assert_called_with("canceled")
assert [(args, kwargs) for args, kwargs in unified_job.save.call_args_list] == [
((), {'update_fields': ['cancel_flag', 'start_args']}),
((), {'update_fields': ['status']}),
]
def test_cancel_job_explanation(unified_job):
job_explanation = 'giggity giggity'
with mock.patch('awx.main.models.unified_jobs.connection.on_commit'):
unified_job.cancel(job_explanation=job_explanation)
assert unified_job.job_explanation == job_explanation
assert [(args, kwargs) for args, kwargs in unified_job.save.call_args_list] == [
((), {'update_fields': ['cancel_flag', 'start_args', 'job_explanation']}),
((), {'update_fields': ['status']}),
]
def test_organization_copy_to_jobs(): def test_organization_copy_to_jobs():
""" """
All unified job types should infer their organization from their template organization All unified job types should infer their organization from their template organization

View File

@@ -226,140 +226,3 @@ def test_send_messages_with_additional_headers():
allow_redirects=False, allow_redirects=False,
) )
assert sent_messages == 1 assert sent_messages == 1
def test_send_messages_with_redirects_ok():
with mock.patch('awx.main.notifications.webhook_backend.requests') as requests_mock, mock.patch(
'awx.main.notifications.webhook_backend.get_awx_http_client_headers'
) as version_mock:
# First two calls return redirects, third call returns 200
requests_mock.post.side_effect = [
mock.Mock(status_code=301, headers={"Location": "http://redirect1.com"}),
mock.Mock(status_code=307, headers={"Location": "http://redirect2.com"}),
mock.Mock(status_code=200),
]
version_mock.return_value = {'Content-Type': 'application/json', 'User-Agent': 'AWX 0.0.1.dev (open)'}
backend = webhook_backend.WebhookBackend('POST', None)
message = EmailMessage(
'test subject',
{'text': 'test body'},
[],
[
'http://example.com',
],
)
sent_messages = backend.send_messages(
[
message,
]
)
assert requests_mock.post.call_count == 3
requests_mock.post.assert_called_with(
url='http://redirect2.com',
auth=None,
data=json.dumps({'text': 'test body'}, ensure_ascii=False).encode('utf-8'),
headers={'Content-Type': 'application/json', 'User-Agent': 'AWX 0.0.1.dev (open)'},
verify=True,
allow_redirects=False,
)
assert sent_messages == 1
def test_send_messages_with_redirects_blank():
with mock.patch('awx.main.notifications.webhook_backend.requests') as requests_mock, mock.patch(
'awx.main.notifications.webhook_backend.get_awx_http_client_headers'
) as version_mock, mock.patch('awx.main.notifications.webhook_backend.logger') as logger_mock:
# First call returns a redirect with Location header, second call returns 301 but NO Location header
requests_mock.post.side_effect = [
mock.Mock(status_code=301, headers={"Location": "http://redirect1.com"}),
mock.Mock(status_code=301, headers={}), # 301 with no Location header
]
version_mock.return_value = {'Content-Type': 'application/json', 'User-Agent': 'AWX 0.0.1.dev (open)'}
backend = webhook_backend.WebhookBackend('POST', None, fail_silently=True)
message = EmailMessage(
'test subject',
{'text': 'test body'},
[],
[
'http://example.com',
],
)
sent_messages = backend.send_messages(
[
message,
]
)
# Should make 2 requests (initial + 1 redirect attempt)
assert requests_mock.post.call_count == 2
# The error message should be logged
logger_mock.error.assert_called_once()
error_call_args = logger_mock.error.call_args[0][0]
assert "redirect to a blank URL" in error_call_args
assert sent_messages == 0
def test_send_messages_with_redirects_max_retries_exceeded():
with mock.patch('awx.main.notifications.webhook_backend.requests') as requests_mock, mock.patch(
'awx.main.notifications.webhook_backend.get_awx_http_client_headers'
) as version_mock, mock.patch('awx.main.notifications.webhook_backend.logger') as logger_mock:
# Return MAX_RETRIES (5) redirect responses to exceed the retry limit
requests_mock.post.side_effect = [
mock.Mock(status_code=301, headers={"Location": "http://redirect1.com"}),
mock.Mock(status_code=301, headers={"Location": "http://redirect2.com"}),
mock.Mock(status_code=307, headers={"Location": "http://redirect3.com"}),
mock.Mock(status_code=301, headers={"Location": "http://redirect4.com"}),
mock.Mock(status_code=307, headers={"Location": "http://redirect5.com"}),
]
version_mock.return_value = {'Content-Type': 'application/json', 'User-Agent': 'AWX 0.0.1.dev (open)'}
backend = webhook_backend.WebhookBackend('POST', None, fail_silently=True)
message = EmailMessage(
'test subject',
{'text': 'test body'},
[],
[
'http://example.com',
],
)
sent_messages = backend.send_messages(
[
message,
]
)
# Should make exactly 5 requests (MAX_RETRIES)
assert requests_mock.post.call_count == 5
# The error message should be logged for exceeding max retries
logger_mock.error.assert_called_once()
error_call_args = logger_mock.error.call_args[0][0]
assert "max number of retries" in error_call_args
assert "[5]" in error_call_args
assert sent_messages == 0
def test_send_messages_with_error_status_code():
with mock.patch('awx.main.notifications.webhook_backend.requests') as requests_mock, mock.patch(
'awx.main.notifications.webhook_backend.get_awx_http_client_headers'
) as version_mock, mock.patch('awx.main.notifications.webhook_backend.logger') as logger_mock:
# Return a 404 error status code
requests_mock.post.return_value = mock.Mock(status_code=404)
version_mock.return_value = {'Content-Type': 'application/json', 'User-Agent': 'AWX 0.0.1.dev (open)'}
backend = webhook_backend.WebhookBackend('POST', None, fail_silently=True)
message = EmailMessage(
'test subject',
{'text': 'test body'},
[],
[
'http://example.com',
],
)
sent_messages = backend.send_messages(
[
message,
]
)
# Should make exactly 1 request
assert requests_mock.post.call_count == 1
# The error message should be logged
logger_mock.error.assert_called_once()
error_call_args = logger_mock.error.call_args[0][0]
assert "Error sending webhook notification: 404" in error_call_args
assert sent_messages == 0

View File

@@ -1,19 +1,20 @@
import pytest import pytest
from django.conf import settings from django.conf import settings
from datetime import timedelta
@pytest.mark.parametrize( @pytest.mark.parametrize(
"task_name", "job_name,function_path",
[ [
'awx.main.tasks.system.awx_periodic_scheduler', ('tower_scheduler', 'awx.main.tasks.system.awx_periodic_scheduler'),
], ],
) )
def test_DISPATCHER_SCHEDULE(mocker, task_name): def test_CELERYBEAT_SCHEDULE(mocker, job_name, function_path):
assert task_name in settings.DISPATCHER_SCHEDULE assert job_name in settings.CELERYBEAT_SCHEDULE
assert 'schedule' in settings.DISPATCHER_SCHEDULE[task_name] assert 'schedule' in settings.CELERYBEAT_SCHEDULE[job_name]
assert type(settings.DISPATCHER_SCHEDULE[task_name]['schedule']) in (int, float) assert type(settings.CELERYBEAT_SCHEDULE[job_name]['schedule']) is timedelta
assert settings.DISPATCHER_SCHEDULE[task_name]['task'] == task_name assert settings.CELERYBEAT_SCHEDULE[job_name]['task'] == function_path
# Ensures that the function exists # Ensures that the function exists
mocker.patch(task_name) mocker.patch(function_path)

View File

@@ -1,4 +1,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os
import tempfile
import shutil
import pytest import pytest
from unittest import mock from unittest import mock
@@ -14,72 +18,25 @@ from awx.main.models import (
Job, Job,
Organization, Organization,
Project, Project,
JobTemplate,
UnifiedJobTemplate,
InstanceGroup,
ExecutionEnvironment,
ProjectUpdate,
InventoryUpdate,
InventorySource,
AdHocCommand,
) )
from awx.main.tasks import jobs from awx.main.tasks import jobs
from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope
@pytest.fixture @pytest.fixture
def private_data_dir(tmp_path): def private_data_dir():
private_data = tmp_path / 'awx_pdd' private_data = tempfile.mkdtemp(prefix='awx_')
private_data.mkdir()
for subfolder in ('inventory', 'env'): for subfolder in ('inventory', 'env'):
(private_data / subfolder).mkdir() runner_subfolder = os.path.join(private_data, subfolder)
return str(private_data) os.makedirs(runner_subfolder, exist_ok=True)
yield private_data
shutil.rmtree(private_data, True)
@pytest.fixture
def job_template_with_credentials():
"""
Factory fixture that creates a job template with specified credentials.
Usage:
job = job_template_with_credentials(ssh_cred, vault_cred)
"""
def _create_job_template(
*credentials, org_name='test-org', project_name='test-project', inventory_name='test-inventory', jt_name='test-jt', playbook='test.yml'
):
"""
Create a job template with the given credentials.
Args:
*credentials: Variable number of Credential objects to attach to the job template
org_name: Name for the organization
project_name: Name for the project
inventory_name: Name for the inventory
jt_name: Name for the job template
playbook: Playbook filename
Returns:
Job instance created from the job template
"""
org = Organization.objects.create(name=org_name)
proj = Project.objects.create(name=project_name, organization=org)
inv = Inventory.objects.create(name=inventory_name, organization=org)
jt = JobTemplate.objects.create(name=jt_name, project=proj, inventory=inv, playbook=playbook)
if credentials:
jt.credentials.add(*credentials)
return jt.create_unified_job()
return _create_job_template
@mock.patch('awx.main.tasks.facts.settings') @mock.patch('awx.main.tasks.facts.settings')
@mock.patch('awx.main.tasks.jobs.create_partition', return_value=True) @mock.patch('awx.main.tasks.jobs.create_partition', return_value=True)
def test_pre_post_run_hook_facts(mock_create_partition, mock_facts_settings, private_data_dir, execution_environment): def test_pre_post_run_hook_facts(mock_create_partition, mock_facts_settings, private_data_dir, execution_environment):
# Create mocked inventory and host queryset # Create mocked inventory and host queryset
inventory = mock.MagicMock(spec=Inventory, pk=1, kind='') inventory = mock.MagicMock(spec=Inventory, pk=1)
host1 = mock.MagicMock(spec=Host, id=1, name='host1', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=inventory) host1 = mock.MagicMock(spec=Host, id=1, name='host1', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=inventory)
host2 = mock.MagicMock(spec=Host, id=2, name='host2', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=inventory) host2 = mock.MagicMock(spec=Host, id=2, name='host2', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=inventory)
@@ -96,16 +53,12 @@ def test_pre_post_run_hook_facts(mock_create_partition, mock_facts_settings, pri
proj = mock.MagicMock(spec=Project, pk=1, organization=org) proj = mock.MagicMock(spec=Project, pk=1, organization=org)
job = mock.MagicMock( job = mock.MagicMock(
spec=Job, spec=Job,
pk=1,
id=1,
use_fact_cache=True, use_fact_cache=True,
project=proj, project=proj,
organization=org, organization=org,
job_slice_number=1, job_slice_number=1,
job_slice_count=1, job_slice_count=1,
inventory=inventory, inventory=inventory,
inventory_id=inventory.pk,
created=now(),
execution_environment=execution_environment, execution_environment=execution_environment,
) )
job.get_hosts_for_fact_cache = Job.get_hosts_for_fact_cache.__get__(job) job.get_hosts_for_fact_cache = Job.get_hosts_for_fact_cache.__get__(job)
@@ -137,11 +90,9 @@ def test_pre_post_run_hook_facts(mock_create_partition, mock_facts_settings, pri
@mock.patch('awx.main.tasks.facts.bulk_update_sorted_by_id') @mock.patch('awx.main.tasks.facts.bulk_update_sorted_by_id')
@mock.patch('awx.main.tasks.facts.settings') @mock.patch('awx.main.tasks.facts.settings')
@mock.patch('awx.main.tasks.jobs.create_partition', return_value=True) @mock.patch('awx.main.tasks.jobs.create_partition', return_value=True)
def test_pre_post_run_hook_facts_deleted_sliced( def test_pre_post_run_hook_facts_deleted_sliced(mock_create_partition, mock_facts_settings, private_data_dir, execution_environment):
mock_create_partition, mock_facts_settings, mock_bulk_update_sorted_by_id, private_data_dir, execution_environment
):
# Fully mocked inventory # Fully mocked inventory
mock_inventory = mock.MagicMock(spec=Inventory, pk=1, kind='') mock_inventory = mock.MagicMock(spec=Inventory)
# Create 999 mocked Host instances # Create 999 mocked Host instances
hosts = [] hosts = []
@@ -167,8 +118,6 @@ def test_pre_post_run_hook_facts_deleted_sliced(
# Mock job object # Mock job object
job = mock.MagicMock(spec=Job) job = mock.MagicMock(spec=Job)
job.pk = 2
job.id = 2
job.use_fact_cache = True job.use_fact_cache = True
job.project = proj job.project = proj
job.organization = org job.organization = org
@@ -176,8 +125,6 @@ def test_pre_post_run_hook_facts_deleted_sliced(
job.job_slice_count = 3 job.job_slice_count = 3
job.execution_environment = execution_environment job.execution_environment = execution_environment
job.inventory = mock_inventory job.inventory = mock_inventory
job.inventory_id = mock_inventory.pk
job.created = now()
job.job_env.get.return_value = private_data_dir job.job_env.get.return_value = private_data_dir
# Bind actual method for host filtering # Bind actual method for host filtering
@@ -241,352 +188,3 @@ def test_invalid_host_facts(mock_facts_settings, bulk_update_sorted_by_id, priva
with pytest.raises(pytest.fail.Exception): with pytest.raises(pytest.fail.Exception):
if failures: if failures:
pytest.fail(f" {len(failures)} facts cleared failures : {','.join(failures)}") pytest.fail(f" {len(failures)} facts cleared failures : {','.join(failures)}")
@pytest.mark.parametrize(
"job_attrs,expected_claims",
[
(
{
'id': 100,
'name': 'Test Job',
'job_type': 'run',
'launch_type': 'manual',
'playbook': 'site.yml',
'organization': Organization(id=1, name='Test Org'),
'inventory': Inventory(id=2, name='Test Inventory'),
'project': Project(id=3, name='Test Project'),
'execution_environment': ExecutionEnvironment(id=4, name='Test EE'),
'job_template': JobTemplate(id=5, name='Test Job Template'),
'unified_job_template': UnifiedJobTemplate(pk=6, id=6, name='Test Unified Job Template'),
'instance_group': InstanceGroup(id=7, name='Test Instance Group'),
},
{
AutomationControllerJobScope.CLAIM_JOB_ID: 100,
AutomationControllerJobScope.CLAIM_JOB_NAME: 'Test Job',
AutomationControllerJobScope.CLAIM_JOB_TYPE: 'run',
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: 'manual',
AutomationControllerJobScope.CLAIM_PLAYBOOK_NAME: 'site.yml',
AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME: 'Test Org',
AutomationControllerJobScope.CLAIM_ORGANIZATION_ID: 1,
AutomationControllerJobScope.CLAIM_INVENTORY_NAME: 'Test Inventory',
AutomationControllerJobScope.CLAIM_INVENTORY_ID: 2,
AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_NAME: 'Test EE',
AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_ID: 4,
AutomationControllerJobScope.CLAIM_PROJECT_NAME: 'Test Project',
AutomationControllerJobScope.CLAIM_PROJECT_ID: 3,
AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_NAME: 'Test Job Template',
AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_ID: 5,
AutomationControllerJobScope.CLAIM_UNIFIED_JOB_TEMPLATE_NAME: 'Test Unified Job Template',
AutomationControllerJobScope.CLAIM_UNIFIED_JOB_TEMPLATE_ID: 6,
AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_NAME: 'Test Instance Group',
AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_ID: 7,
},
),
(
{'id': 100, 'name': 'Test', 'job_type': 'run', 'launch_type': 'manual', 'organization': Organization(id=1, name='')},
{
AutomationControllerJobScope.CLAIM_JOB_ID: 100,
AutomationControllerJobScope.CLAIM_JOB_NAME: 'Test',
AutomationControllerJobScope.CLAIM_JOB_TYPE: 'run',
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: 'manual',
AutomationControllerJobScope.CLAIM_ORGANIZATION_ID: 1,
AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME: '',
AutomationControllerJobScope.CLAIM_PLAYBOOK_NAME: '',
},
),
],
)
def test_populate_claims_for_workload(job_attrs, expected_claims):
job = Job()
for attr, value in job_attrs.items():
setattr(job, attr, value)
claims = jobs.populate_claims_for_workload(job)
assert claims == expected_claims
@pytest.mark.parametrize(
"workload_attrs,expected_claims",
[
(
{
'id': 200,
'name': 'Git Sync',
'job_type': 'check',
'launch_type': 'sync',
'organization': Organization(id=1, name='Test Org'),
'project': Project(pk=3, id=3, name='Test Project'),
'unified_job_template': Project(pk=3, id=3, name='Test Project'),
'execution_environment': ExecutionEnvironment(id=4, name='Test EE'),
'instance_group': InstanceGroup(id=7, name='Test Instance Group'),
},
{
AutomationControllerJobScope.CLAIM_JOB_ID: 200,
AutomationControllerJobScope.CLAIM_JOB_NAME: 'Git Sync',
AutomationControllerJobScope.CLAIM_JOB_TYPE: 'check',
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: 'sync',
AutomationControllerJobScope.CLAIM_LAUNCHED_BY_NAME: 'Test Project',
AutomationControllerJobScope.CLAIM_LAUNCHED_BY_ID: 3,
AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME: 'Test Org',
AutomationControllerJobScope.CLAIM_ORGANIZATION_ID: 1,
AutomationControllerJobScope.CLAIM_PROJECT_NAME: 'Test Project',
AutomationControllerJobScope.CLAIM_PROJECT_ID: 3,
AutomationControllerJobScope.CLAIM_UNIFIED_JOB_TEMPLATE_NAME: 'Test Project',
AutomationControllerJobScope.CLAIM_UNIFIED_JOB_TEMPLATE_ID: 3,
AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_NAME: 'Test EE',
AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_ID: 4,
AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_NAME: 'Test Instance Group',
AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_ID: 7,
},
),
(
{
'id': 201,
'name': 'Minimal Project Update',
'job_type': 'run',
'launch_type': 'manual',
},
{
AutomationControllerJobScope.CLAIM_JOB_ID: 201,
AutomationControllerJobScope.CLAIM_JOB_NAME: 'Minimal Project Update',
AutomationControllerJobScope.CLAIM_JOB_TYPE: 'run',
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: 'manual',
},
),
],
)
def test_populate_claims_for_project_update(workload_attrs, expected_claims):
project_update = ProjectUpdate()
for attr, value in workload_attrs.items():
setattr(project_update, attr, value)
claims = jobs.populate_claims_for_workload(project_update)
assert claims == expected_claims
@pytest.mark.parametrize(
"workload_attrs,expected_claims",
[
(
{
'id': 300,
'name': 'AWS Sync',
'launch_type': 'scheduled',
'organization': Organization(id=1, name='Test Org'),
'inventory': Inventory(id=2, name='AWS Inventory'),
'unified_job_template': InventorySource(pk=8, id=8, name='AWS Source'),
'execution_environment': ExecutionEnvironment(id=4, name='Test EE'),
'instance_group': InstanceGroup(id=7, name='Test Instance Group'),
},
{
AutomationControllerJobScope.CLAIM_JOB_ID: 300,
AutomationControllerJobScope.CLAIM_JOB_NAME: 'AWS Sync',
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: 'scheduled',
AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME: 'Test Org',
AutomationControllerJobScope.CLAIM_ORGANIZATION_ID: 1,
AutomationControllerJobScope.CLAIM_INVENTORY_NAME: 'AWS Inventory',
AutomationControllerJobScope.CLAIM_INVENTORY_ID: 2,
AutomationControllerJobScope.CLAIM_UNIFIED_JOB_TEMPLATE_NAME: 'AWS Source',
AutomationControllerJobScope.CLAIM_UNIFIED_JOB_TEMPLATE_ID: 8,
AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_NAME: 'Test EE',
AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_ID: 4,
AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_NAME: 'Test Instance Group',
AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_ID: 7,
},
),
(
{
'id': 301,
'name': 'Minimal Inventory Update',
'launch_type': 'manual',
},
{
AutomationControllerJobScope.CLAIM_JOB_ID: 301,
AutomationControllerJobScope.CLAIM_JOB_NAME: 'Minimal Inventory Update',
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: 'manual',
},
),
],
)
def test_populate_claims_for_inventory_update(workload_attrs, expected_claims):
inventory_update = InventoryUpdate()
for attr, value in workload_attrs.items():
setattr(inventory_update, attr, value)
claims = jobs.populate_claims_for_workload(inventory_update)
assert claims == expected_claims
@pytest.mark.parametrize(
"workload_attrs,expected_claims",
[
(
{
'id': 400,
'name': 'Ping All Hosts',
'job_type': 'run',
'launch_type': 'manual',
'organization': Organization(id=1, name='Test Org'),
'inventory': Inventory(id=2, name='Test Inventory'),
'execution_environment': ExecutionEnvironment(id=4, name='Test EE'),
'instance_group': InstanceGroup(id=7, name='Test Instance Group'),
},
{
AutomationControllerJobScope.CLAIM_JOB_ID: 400,
AutomationControllerJobScope.CLAIM_JOB_NAME: 'Ping All Hosts',
AutomationControllerJobScope.CLAIM_JOB_TYPE: 'run',
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: 'manual',
AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME: 'Test Org',
AutomationControllerJobScope.CLAIM_ORGANIZATION_ID: 1,
AutomationControllerJobScope.CLAIM_INVENTORY_NAME: 'Test Inventory',
AutomationControllerJobScope.CLAIM_INVENTORY_ID: 2,
AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_NAME: 'Test EE',
AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_ID: 4,
AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_NAME: 'Test Instance Group',
AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_ID: 7,
},
),
(
{
'id': 401,
'name': 'Minimal Ad Hoc',
'job_type': 'run',
'launch_type': 'manual',
},
{
AutomationControllerJobScope.CLAIM_JOB_ID: 401,
AutomationControllerJobScope.CLAIM_JOB_NAME: 'Minimal Ad Hoc',
AutomationControllerJobScope.CLAIM_JOB_TYPE: 'run',
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: 'manual',
},
),
],
)
def test_populate_claims_for_adhoc_command(workload_attrs, expected_claims):
adhoc_command = AdHocCommand()
for attr, value in workload_attrs.items():
setattr(adhoc_command, attr, value)
claims = jobs.populate_claims_for_workload(adhoc_command)
assert claims == expected_claims
@mock.patch('awx.main.tasks.jobs.get_workload_identity_client')
def test_retrieve_workload_identity_jwt_returns_jwt_from_client(mock_get_client):
"""retrieve_workload_identity_jwt returns the JWT string from the client."""
mock_client = mock.MagicMock()
mock_response = mock.MagicMock()
mock_response.jwt = 'eyJ.test.jwt'
mock_client.request_workload_jwt.return_value = mock_response
mock_get_client.return_value = mock_client
unified_job = Job()
unified_job.id = 42
unified_job.name = 'Test Job'
unified_job.launch_type = 'manual'
unified_job.organization = Organization(id=1, name='Test Org')
unified_job.unified_job_template = None
unified_job.instance_group = None
result = jobs.retrieve_workload_identity_jwt(unified_job, audience='https://api.example.com', scope='aap_controller_automation_job')
assert result == 'eyJ.test.jwt'
mock_client.request_workload_jwt.assert_called_once()
call_kwargs = mock_client.request_workload_jwt.call_args[1]
assert call_kwargs['audience'] == 'https://api.example.com'
assert call_kwargs['scope'] == 'aap_controller_automation_job'
assert 'claims' in call_kwargs
assert call_kwargs['claims'][AutomationControllerJobScope.CLAIM_JOB_ID] == 42
assert call_kwargs['claims'][AutomationControllerJobScope.CLAIM_JOB_NAME] == 'Test Job'
@mock.patch('awx.main.tasks.jobs.get_workload_identity_client')
def test_retrieve_workload_identity_jwt_passes_audience_and_scope(mock_get_client):
"""retrieve_workload_identity_jwt passes audience and scope to the client."""
mock_client = mock.MagicMock()
mock_client.request_workload_jwt.return_value = mock.MagicMock(jwt='token')
mock_get_client.return_value = mock_client
unified_job = mock.MagicMock()
audience = 'custom_audience'
scope = 'custom_scope'
with mock.patch('awx.main.tasks.jobs.populate_claims_for_workload', return_value={'job_id': 1}):
jobs.retrieve_workload_identity_jwt(unified_job, audience=audience, scope=scope)
mock_client.request_workload_jwt.assert_called_once_with(claims={'job_id': 1}, scope=scope, audience=audience)
@mock.patch('awx.main.tasks.jobs.get_workload_identity_client')
def test_retrieve_workload_identity_jwt_passes_workload_ttl(mock_get_client):
"""retrieve_workload_identity_jwt passes workload_ttl_seconds when provided."""
mock_client = mock.Mock()
mock_client.request_workload_jwt.return_value = mock.Mock(jwt='token')
mock_get_client.return_value = mock_client
unified_job = mock.MagicMock()
with mock.patch('awx.main.tasks.jobs.populate_claims_for_workload', return_value={'job_id': 1}):
jobs.retrieve_workload_identity_jwt(
unified_job,
audience='https://vault.example.com',
scope='aap_controller_automation_job',
workload_ttl_seconds=3600,
)
mock_client.request_workload_jwt.assert_called_once_with(
claims={'job_id': 1},
scope='aap_controller_automation_job',
audience='https://vault.example.com',
workload_ttl_seconds=3600,
)
@mock.patch('awx.main.tasks.jobs.get_workload_identity_client')
def test_retrieve_workload_identity_jwt_raises_when_client_not_configured(mock_get_client):
"""retrieve_workload_identity_jwt raises RuntimeError when client is None."""
mock_get_client.return_value = None
unified_job = mock.MagicMock()
with pytest.raises(RuntimeError, match="Workload identity client is not configured"):
jobs.retrieve_workload_identity_jwt(unified_job, audience='test_audience', scope='test_scope')
@pytest.mark.parametrize('effective_timeout,expected_ttl', [(3600, 3600), (0, None)])
@mock.patch('awx.main.tasks.jobs.retrieve_workload_identity_jwt')
@mock.patch('awx.main.tasks.jobs.flag_enabled', return_value=True)
def test_populate_workload_identity_tokens_passes_get_instance_timeout_to_client(mock_flag_enabled, mock_retrieve_jwt, effective_timeout, expected_ttl):
    """populate_workload_identity_tokens passes get_instance_timeout() value as workload_ttl_seconds to retrieve_workload_identity_jwt."""
    mock_retrieve_jwt.return_value = 'eyJ.test.jwt'

    # Minimal credential fixture: one input source whose source credential
    # exposes a single internal workload-identity field.
    source = mock.MagicMock()
    source.pk = 1
    source.source_credential = mock.MagicMock()
    source.source_credential.get_input.return_value = 'https://vault.example.com'
    source.source_credential.name = 'vault-cred'
    source.source_credential.credential_type = mock.MagicMock()
    source.source_credential.credential_type.inputs = {'fields': [{'id': 'workload_identity_token', 'internal': True}]}

    cred = mock.MagicMock()
    cred.context = {}
    cred.input_sources = mock.MagicMock()
    cred.input_sources.all.return_value = [source]

    task = jobs.RunJob()
    task.instance = mock.MagicMock()
    task._credentials = [cred]

    with mock.patch.object(task, 'get_instance_timeout', return_value=effective_timeout):
        task.populate_workload_identity_tokens()

    # The feature flag is consulted exactly once, and the instance timeout is
    # translated into the TTL handed to the JWT retrieval helper (0 -> None).
    mock_flag_enabled.assert_called_once_with("FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED")
    mock_retrieve_jwt.assert_called_once_with(
        task.instance,
        audience='https://vault.example.com',
        scope=AutomationControllerJobScope.name,
        workload_ttl_seconds=expected_ttl,
    )

View File

@@ -12,10 +12,6 @@ def pytest_sigterm():
pytest_sigterm.called_count += 1 pytest_sigterm.called_count += 1
def pytest_sigusr1():
pytest_sigusr1.called_count += 1
def tmp_signals_for_test(func): def tmp_signals_for_test(func):
""" """
When we run our internal signal handlers, it will call the original signal When we run our internal signal handlers, it will call the original signal
@@ -30,17 +26,13 @@ def tmp_signals_for_test(func):
def wrapper(): def wrapper():
original_sigterm = signal.getsignal(signal.SIGTERM) original_sigterm = signal.getsignal(signal.SIGTERM)
original_sigint = signal.getsignal(signal.SIGINT) original_sigint = signal.getsignal(signal.SIGINT)
original_sigusr1 = signal.getsignal(signal.SIGUSR1)
signal.signal(signal.SIGTERM, pytest_sigterm) signal.signal(signal.SIGTERM, pytest_sigterm)
signal.signal(signal.SIGINT, pytest_sigint) signal.signal(signal.SIGINT, pytest_sigint)
signal.signal(signal.SIGUSR1, pytest_sigusr1)
pytest_sigterm.called_count = 0 pytest_sigterm.called_count = 0
pytest_sigint.called_count = 0 pytest_sigint.called_count = 0
pytest_sigusr1.called_count = 0
func() func()
signal.signal(signal.SIGTERM, original_sigterm) signal.signal(signal.SIGTERM, original_sigterm)
signal.signal(signal.SIGINT, original_sigint) signal.signal(signal.SIGINT, original_sigint)
signal.signal(signal.SIGUSR1, original_sigusr1)
return wrapper return wrapper
@@ -66,13 +58,11 @@ def test_outer_inner_signal_handling():
assert signal_callback() is False assert signal_callback() is False
assert pytest_sigterm.called_count == 0 assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 0 assert pytest_sigint.called_count == 0
assert pytest_sigusr1.called_count == 0
f1() f1()
assert signal_callback() is False assert signal_callback() is False
assert signal.getsignal(signal.SIGTERM) is original_sigterm assert signal.getsignal(signal.SIGTERM) is original_sigterm
assert pytest_sigterm.called_count == 1 assert pytest_sigterm.called_count == 1
assert pytest_sigint.called_count == 0 assert pytest_sigint.called_count == 0
assert pytest_sigusr1.called_count == 0
@tmp_signals_for_test @tmp_signals_for_test
@@ -97,31 +87,8 @@ def test_inner_outer_signal_handling():
assert signal_callback() is False assert signal_callback() is False
assert pytest_sigterm.called_count == 0 assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 0 assert pytest_sigint.called_count == 0
assert pytest_sigusr1.called_count == 0
f1() f1()
assert signal_callback() is False assert signal_callback() is False
assert signal.getsignal(signal.SIGTERM) is original_sigterm assert signal.getsignal(signal.SIGTERM) is original_sigterm
assert pytest_sigterm.called_count == 0 assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 1 assert pytest_sigint.called_count == 1
assert pytest_sigusr1.called_count == 0
@tmp_signals_for_test
def test_sigusr1_signal_handling():
@with_signal_handling
def f1():
assert signal_callback() is False
signal_state.set_signal_flag(for_signal=signal.SIGUSR1)
assert signal_callback()
original_sigusr1 = signal.getsignal(signal.SIGUSR1)
assert signal_callback() is False
assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 0
assert pytest_sigusr1.called_count == 0
f1()
assert signal_callback() is False
assert signal.getsignal(signal.SIGUSR1) is original_sigusr1
assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 0
assert pytest_sigusr1.called_count == 1

View File

@@ -76,9 +76,6 @@ def test_custom_error_messages(schema, given, message):
({'fields': [{'id': 'token', 'label': 'Token', 'secret': 'bad'}]}, False), ({'fields': [{'id': 'token', 'label': 'Token', 'secret': 'bad'}]}, False),
({'fields': [{'id': 'token', 'label': 'Token', 'ask_at_runtime': True}]}, True), ({'fields': [{'id': 'token', 'label': 'Token', 'ask_at_runtime': True}]}, True),
({'fields': [{'id': 'token', 'label': 'Token', 'ask_at_runtime': 'bad'}]}, False), # noqa ({'fields': [{'id': 'token', 'label': 'Token', 'ask_at_runtime': 'bad'}]}, False), # noqa
({'fields': [{'id': 'token', 'label': 'Token', 'internal': True}]}, True),
({'fields': [{'id': 'token', 'label': 'Token', 'internal': False}]}, True),
({'fields': [{'id': 'token', 'label': 'Token', 'internal': 'bad'}]}, False),
({'fields': [{'id': 'become_method', 'label': 'Become', 'choices': 'not-a-list'}]}, False), # noqa ({'fields': [{'id': 'become_method', 'label': 'Become', 'choices': 'not-a-list'}]}, False), # noqa
({'fields': [{'id': 'become_method', 'label': 'Become', 'choices': []}]}, False), ({'fields': [{'id': 'become_method', 'label': 'Become', 'choices': []}]}, False),
({'fields': [{'id': 'become_method', 'label': 'Become', 'choices': ['su', 'sudo']}]}, True), # noqa ({'fields': [{'id': 'become_method', 'label': 'Become', 'choices': ['su', 'sudo']}]}, True), # noqa
@@ -207,68 +204,6 @@ def test_credential_creation_validation_failure(inputs):
assert e.type in (ValidationError, DRFValidationError) assert e.type in (ValidationError, DRFValidationError)
def test_credential_input_field_excludes_internal_fields():
"""Internal fields should be excluded from the schema generated by CredentialInputField,
preventing users from providing values for internally resolved fields."""
type_ = CredentialType(
kind='cloud',
name='SomeCloud',
managed=True,
inputs={
'fields': [
{'id': 'username', 'label': 'Username', 'type': 'string'},
{'id': 'resolved_token', 'label': 'Token', 'type': 'string', 'internal': True},
]
},
)
cred = Credential(credential_type=type_, name="Test Credential", inputs={'username': 'joe'})
field = cred._meta.get_field('inputs')
schema = field.schema(cred)
assert 'username' in schema['properties']
assert 'resolved_token' not in schema['properties']
def test_credential_input_field_rejects_values_for_internal_fields():
"""Users should not be able to provide values for fields marked as internal."""
type_ = CredentialType(
kind='cloud',
name='SomeCloud',
managed=True,
inputs={
'fields': [
{'id': 'username', 'label': 'Username', 'type': 'string'},
{'id': 'resolved_token', 'label': 'Token', 'type': 'string', 'internal': True},
]
},
)
cred = Credential(credential_type=type_, name="Test Credential", inputs={'username': 'joe', 'resolved_token': 'secret'})
field = cred._meta.get_field('inputs')
with pytest.raises(Exception) as e:
field.validate(cred.inputs, cred)
assert e.type in (ValidationError, DRFValidationError)
def test_credential_input_field_accepts_non_internal_fields_only():
"""Credentials with only non-internal field values should validate successfully."""
type_ = CredentialType(
kind='cloud',
name='SomeCloud',
managed=True,
inputs={
'fields': [
{'id': 'username', 'label': 'Username', 'type': 'string'},
{'id': 'resolved_token', 'label': 'Token', 'type': 'string', 'internal': True},
]
},
)
cred = Credential(credential_type=type_, name="Test Credential", inputs={'username': 'joe'})
field = cred._meta.get_field('inputs')
# Should not raise
field.validate(cred.inputs, cred)
def test_implicit_role_field_parents(): def test_implicit_role_field_parents():
"""This assures that every ImplicitRoleField only references parents """This assures that every ImplicitRoleField only references parents
which are relationships that actually exist which are relationships that actually exist

Some files were not shown because too many files have changed in this diff Show More