Compare commits

...

60 Commits

Author SHA1 Message Date
Peter Braun
ed1b5c5519 exclude password fields from skipping if not defined 2026-02-17 12:19:43 +01:00
Peter Braun
d2e51c4124 do not add optional survey fields with empty strings that are not backed by extra_vars 2026-02-17 10:13:33 +01:00
Chris Meyers
08f1507f70 Change remote host finding logic
* When the remote host header values contains a comma separated list,
  only consider the first entry. Previously we considered every item in
  the list.
2026-02-16 15:46:47 -05:00
Stevenson Michel
994a2b3c04 [Devel][AAP-65384]Restoration of Token Authentication for AWX CLI (#16281)
* Added token authentication in logic, arguments, and test
2026-02-16 10:14:31 -05:00
Seth Foster
7ccc14daeb Remove stale api:schema-swagger-ui reference from API root (#16282)
The schema-swagger-ui URL was removed from awx/api/urls/urls.py in
d7eb714859 when docs endpoints moved to DAB's api_documentation app,
but the reverse call in ApiRootView was not removed, causing a
NoReverseMatch error in development mode.

Signed-off-by: Seth Foster <fosterbseth@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-02-12 20:12:34 +00:00
Seth Foster
9700fb01f2 Fix awx CLI modify command for users with object-level permissions (#16276)
The awx CLI derives available fields for the `modify` command from
the list endpoint's POST action schema. Users with object-level
admin permissions (e.g., Project Admin) but no list-level POST
permission see no field flags, making modify unusable despite having
PUT access on the detail endpoint.

Fall back to the detail endpoint's action schema when POST is not
available on the list endpoint, and prefer PUT over POST when
building modify arguments.

Signed-off-by: Seth Foster <fosterbseth@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-02-12 13:26:15 -05:00
Seth Foster
c515b86fa6 Bump wheel to address CVE-2026-24049 (#16253)
Signed-off-by: Seth Foster <fosterbseth@gmail.com>
2026-02-12 13:09:25 -05:00
Chris Meyers
01293f1b45 Restore github app lookup tests
* Introduced in PR https://github.com/ansible/awx/pull/16058/changes
  then a later large merge from AAP back into devel removed the changes
* This PR re-introduces the github app lookup migration rename tests
  with the migration names updated and the kind to namespace correction
2026-02-12 11:45:10 -05:00
Chris Meyers
fd847862a7 Fix wrong field rename
* Multiple credentialtype's have the same kind and kind values look
  like: cloud, network, machine, etc.
* namespace is the field that we want to rename
2026-02-12 11:45:10 -05:00
Dave
980d9db192 fix: align pip version constraint in requirements_dev.txt (#16275)
fix: align pip version constraint in requirements_dev.txt with requirements.txt (fixes #16272)

requirements.txt pins pip==25.3 while requirements_dev.txt specified
pip>=21.3,<=24.0, causing ResolutionImpossible when installing both.
Updated requirements_dev.txt to use pip>=25.3 to maintain compatibility.
2026-02-12 11:35:01 -05:00
Alan Rominger
f2438a0e86 Fix server error from PATCH to inventory source (#16274)
Fix server error from PATCH to inventory source, co-authored with Claude opus 4.6
2026-02-11 15:10:32 -05:00
Rodrigo Toshiaki Horie
707f2fa5da Add OpenAPI spec sync workflow (#16267) 2026-02-10 19:13:47 -03:00
Rodrigo Toshiaki Horie
1f18396438 Add CI Checks for syntactically valid OpenAPI Specification (#16266) 2026-02-10 19:13:34 -03:00
melissalkelly
6f0cfb5ace AAP-62657 Implement logic to extract and populate JWT claims from Controller Jobs (#16259)
* AAP-62657 Add populate_claims_for_workload function and unit tests

* Update safe_get helper function

* Trigger CI rebuild to pick up latest django-ansible-base

* Trigger CI after org visibility update

* Retrigger CI

* Rename workload to job, refine safe_get helper function

* Update test_jobs to use job fixture

* Retrigger CI

* Create fresh job, removed launched_by since this is read-only property

* Retrigger CI after runner issues

* Retrigger CI after runner issues

* Add unit tests for other workload types

* Update CLAIM_LAUNCHED_BY_USER_NAME and CLAIM_LAUNCHED_BY_USER_ID, with CLAIM_LAUNCHED_BY_NAME and CLAIM_LAUNCHED_BY_ID

* Generate claims with a more static schema

try to operate directly on object when possible

For cases where field is valid for the type, but null value
  still add the field, so blank and null values appear

* Allow unified related items to be omittied

---------

Co-authored-by: AlanCoding <arominge@redhat.com>
2026-02-09 20:58:49 +00:00
🇺🇦 Sviatoslav Sydorenko (Святослав Сидоренко)
fc0a4cddce 🧪 Use the unified test reporting action (#16168) 2026-02-09 09:54:04 -05:00
Seth Foster
99511efe81 bump pyasn1 (#16249)
Bump pyasn1 for CVE-2026-2349

Signed-off-by: Seth Foster <fosterbseth@gmail.com>
2026-02-05 14:12:59 -05:00
Rodrigo Toshiaki Horie
30bf910bd5 fix schema generator (#16265) 2026-02-04 20:54:28 -03:00
jessicamack
c9085e4b7f Update OpenAPI spec to improve descriptions and messages (#16260)
* Update OpenAPI spec

* lint fixes

* fix decorator for retrieve endpoints

* change decorator method

* fix import

* lint fix
2026-02-04 22:32:57 +00:00
Alan Rominger
5e93f60b9e AAP-41776 Enable new fancy asyncio metrics for dispatcherd (#16233)
* Enable new fancy asyncio metrics for dispatcherd

Remove old dispatcher metrics and patch in new data from local whatever

Update test fixture to new dispatcherd version

* Update dispatcherd again

* Handle node filter in URL, and catch more errors

* Add test for metric filter

* Split module for dispatcherd metrics
2026-02-04 15:28:34 -05:00
Rodrigo Toshiaki Horie
6a031158ce Fix OpenAPI schema enum values for CredentialType kind field (#16262)
The OpenAPI schema incorrectly showed all 12 credential type kinds as
valid for POST/PUT/PATCH operations, when only 'cloud' and 'net' are
allowed for custom credential types. This caused API clients and LLM
agents to receive HTTP 400 errors when attempting to create credential
types with invalid kind values.

Add postprocessing hook to filter CredentialTypeRequest and
PatchedCredentialTypeRequest schemas to only show 'cloud', 'net',
and null as valid enum values, matching the existing validation logic.

No API behavior changes - this is purely a documentation fix.

Co-authored-by: Claude <noreply@anthropic.com>
2026-02-04 16:39:03 -03:00
joeywashburn
749735b941 Standardize spelling of 'canceled' in wsrelay.py (#16178)
Changed two instances of 'cancelled' to 'canceled' in awx/main/wsrelay.py
to match AWX's standardized American English spelling convention.

- Updated log message in WebsocketRelayConnection.connect()
- Updated comment in WebSocketRelayManager.cleanup_offline_host()

Fixes #15177

Signed-off-by: Joey Washburn <joey@joeywashburn.com>
2026-02-04 12:29:45 -05:00
Chris Meyers
315f9c7eef Rename args var
* https://sonarcloud.io/project/issues?open=AZDmRbV12PiUXMD3dYmh&id=ansible_awx
2026-02-04 08:17:51 -05:00
Chris Meyers
00c0f7e8db add test 2026-02-03 16:12:22 -05:00
Chris Meyers
37ccbc28bd Harden log message output containing user input
* base64 encode user inputed url when logging so that newlines or other
  malicious payloads can't be injected into the log stream
2026-02-03 16:12:22 -05:00
Chris Meyers
63fafec76f Remove init return value
* https://sonarcloud.io/project/issues?open=AZDmRaaJ2PiUXMD3dXly&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
cba01339a1 Remove redunant role attribute
* https://sonarcloud.io/project/issues?open=AZDmRbYN2PiUXMD3dYo5&id=ansible_awx&tab=code
2026-02-03 16:12:00 -05:00
Chris Meyers
2622e9d295 Add alt text for awx logo
* https://sonarcloud.io/project/issues?open=AZDmRbYR2PiUXMD3dYo6&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
a6afec6ebb Add generic font family
* https://sonarcloud.io/project/issues?open=AZDmRbX42PiUXMD3dYo1&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
f406a377f7 Add generic font family
* https://sonarcloud.io/project/issues?open=AZDmRbX42PiUXMD3dYo0&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
adc3e35978 Add generic font family
* remove quotes so that it's treated as a generic font family
* https://sonarcloud.io/project/issues?open=AZDmRbX42PiUXMD3dYoz&id=ansible_awx&tab=code
2026-02-03 16:12:00 -05:00
Chris Meyers
838e67005c Remove duplicate css property
* Last one wins so remove the first one.
* https://sonarcloud.io/project/issues?open=AZpWSq7yO74rjWmAOcwf&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
e13fcfe29f Add alt text to 504 image
* https://sonarcloud.io/project/issues?open=AZDmRbXt2PiUXMD3dYor&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
0f4e91419a Add lang english tag to 504 page
* https://sonarcloud.io/project/issues?open=AZDmRbXt2PiUXMD3dYop&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
cca70b242a Add alt text to 502 image
* https://sonarcloud.io/project/issues?open=AZDmRbXx2PiUXMD3dYou&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
edf459f8ec Add language english to 502
* https://sonarcloud.io/project/issues?open=AZDmRbXx2PiUXMD3dYos&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
f4286216d6 Add doctype and lang
* https://sonarcloud.io/project/issues?open=AZDmRbX02PiUXMD3dYov&id=ansible_awx
* https://sonarcloud.io/project/issues?open=AZDmRbX02PiUXMD3dYow&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
0ab1fea731 Use replaceAll() for global regex
* https://sonarcloud.io/project/issues?open=AZlyqQeaRtfhwxlTsfP0&id=ansible_awx
* https://sonarcloud.io/project/issues?open=AZlyqQeaRtfhwxlTsfP1&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
e3ac581fdf Always use a tz aware timestamp
* https://sonarcloud.io/project/issues?open=AZDmRade2PiUXMD3dXnx&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
5aa3e8cf3b Make tz aware
* Note, this doesn't change the logic or output, but maybe makes
  someones life easier later.
* https://sonarcloud.io/project/issues?open=AZDmRade2PiUXMD3dXnx&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
8289003c0d Remove unreachable code path
* https://sonarcloud.io/project/issues?open=AZDmRaXd2PiUXMD3dXkN&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
125083538a Compare float to float
* https://sonarcloud.io/project/issues?open=AZDmRaXd2PiUXMD3dXkF&id=ansible_awx&tab=code
2026-02-03 16:12:00 -05:00
Chris Meyers
ed5ab8becd Remove unused variable
* https://sonarcloud.io/project/issues?open=AZL9AcQZ0bcYsK7qDrTo&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
fc0087f1b2 Add language to api stdout for translation helping
* https://sonarcloud.io/project/issues?open=AZDmRbV82PiUXMD3dYmx&id=ansible_awx&tab=code
2026-02-03 16:12:00 -05:00
Chris Meyers
cfc5ad9d91 Remove return value from __init__
* https://sonarcloud.io/project/issues?open=AZDmRaZX2PiUXMD3dXle&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
d929b767b6 Rename kwargs
* https://sonarcloud.io/project/issues?open=AZDmRbVW2PiUXMD3dYmW&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
5f434ac348 Rename exception args variable
* https://sonarcloud.io/project/issues?open=AZDmRbV12PiUXMD3dYmg&id=ansible_awx
* https://sonarcloud.io/project/issues?open=AZDmRaZX2PiUXMD3dXle&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
4de9c8356b Use fromkeys for constant
* https://sonarcloud.io/project/issues?open=AZeD0GsJyrLLb-kZUOoF&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
91118adbd3 Fix summary_dict None check
* https://sonarcloud.io/project/issues?open=AZDmRbWy2PiUXMD3dYoJ&id=ansible_awx
2026-02-03 16:12:00 -05:00
Chris Meyers
25f538277a Fix init return
* __init__ shouldn't return
2026-02-03 16:12:00 -05:00
joeywashburn
82cb52d648 Sanitize SSH key whitespace to prevent validation errors (#16179)
Strip leading and trailing whitespace from SSH keys in validate_ssh_private_key()
to handle common copy-paste scenarios where hidden newlines cause base64 decoding
failures.

Changes:
- Added data.strip() in validate_ssh_private_key() before calling validate_pem()
- Added test_ssh_key_with_whitespace() to verify keys with leading/trailing
  newlines are properly sanitized and validated

This prevents the confusing "HTTP 500: Internal Server Error" and
"binascii.Error: Incorrect padding" errors when users paste SSH keys with
accidental whitespace.

Fixes #14219

Signed-off-by: Joey Washburn <joey@joeywashburn.com>
2026-02-02 11:16:28 -05:00
Peter Braun
f7958b93bd add deprecated fields to x-ai-description for credential post (#16255) 2026-01-29 18:17:31 +01:00
Alan Rominger
3d68ca848e Fix race condition of un-expired cache in local workers (#16256) 2026-01-29 11:31:06 -05:00
Adrià Sala
99dce79078 fix: add py311 to make version detection 2026-01-28 14:20:48 +01:00
Alan Rominger
271383d018 AAP-60470 Add dispatcherctl and dispatcherd commands as updated interface to dispatcherd lib (#16206)
* Add dispatcherctl command

* Add tests for dispatcherctl command

* Exit early if sqlite3

* Switch to dispatcherd mgmt cmd

* Move unwanted command options to run_dispatcher

* Add test for new stuff

* Update the SOS report status command

* make docs always reference new command

* Consistently error if given config file
2026-01-27 15:57:23 -05:00
Alan Rominger
1128ad5a57 AAP-64221 Fix broken cancel logic with dispatcherd (#16247)
* Fix broken cancel logic with dispatcherd

Update tests for UnifiedJob

Update test assertion

* Further simply cancel path
2026-01-27 14:39:08 -05:00
Seth Foster
823b736afe Remove unused INSIGHTS_OIDC_ENDPOINT (#16235)
This setting is set in defaults.py, but
currently not being used. More technically,
project_update.yml is not passing this value to
the insights.py action plugin. Therefore, we
can safely remove references to it.

insights.py already has a default oidc endpoint
defined for authentication.

Signed-off-by: Seth Foster <fosterbseth@gmail.com>
2026-01-27 10:13:37 -05:00
Alan Rominger
f80bbc57d8 AAP-43117 Additional dispatcher removal simplifications and waiting reaper updates (#16243)
* Additional dispatcher removal simplifications and waiting repear updates

* Fix double call and logging message

* Implement bugbot comment, should reap running on lost instances

* Add test case for new pending behavior
2026-01-26 13:55:37 -05:00
TVo
12a7229ee9 Publish open api spec on AWX for community use (#16221)
* Added link and ref to openAPI spec for community

* Update docs/docsite/rst/contributor/openapi_link.rst

Co-authored-by: Don Naro <dnaro@redhat.com>

* add sphinxcontrib-redoc to requirements

* sphinxcontrib.redoc configuration

* create openapi directory and files

* update download script for both schema files

* suppress warning for redoc

* update labels

* fix extra closing parenthesis

* update schema url

* exclude doc config and download script

The Sphinx configuration (conf.py) and schema download script
(download-json.py) are not application logic and used only for building
documentation. Coverage requirements for these files are overkill.

* exclude only the sphinx config file

---------

Co-authored-by: Don Naro <dnaro@redhat.com>
2026-01-26 18:16:49 +00:00
Don Naro
ceed692354 change contact email address (#16245)
Use the ansible-community@redhat.com alias as the contact email address.
2026-01-26 17:28:00 +00:00
Jake Jackson
36a00ec46b AAP-58539 Move to dispatcherd (#16209)
* WIP First pass
* started removing feature flags and adjusting logic
* Add decorator
* moved to dispatcher decorator
* updated as many as I could find
* Keep callback receiver working
* remove any code that is not used by the call back receiver
* add back auto_max_workers
* added back get_auto_max_workers into common utils
* Remove control and hazmat (squash this not done)
* moved status out and deleted control as no longer needed
* removed unused imports
* adjusted test import to pull correct method
* fixed imports and addressed clusternode heartbeat test
* Update function comments
* Add back hazmat for config and remove baseworker
* added back hazmat per @alancoding feedback around config
* removed baseworker completely and refactored it into the callback
  worker
* Fix dispatcher run call and remove dispatch setting
* remove dispatcher mock publish setting
* Adjust heartbeat arg and more formatting
* fixed the call to cluster_node_heartbeat missing binder
* Fix attribute error in server logs
2026-01-23 20:49:32 +00:00
119 changed files with 2541 additions and 2405 deletions

View File

@@ -45,15 +45,45 @@ jobs:
make docker-runner 2>&1 | tee schema-diff.txt make docker-runner 2>&1 | tee schema-diff.txt
exit ${PIPESTATUS[0]} exit ${PIPESTATUS[0]}
- name: Add schema diff to job summary - name: Validate OpenAPI schema
id: schema-validation
continue-on-error: true
run: |
AWX_DOCKER_ARGS='-e GITHUB_ACTIONS' \
AWX_DOCKER_CMD='make validate-openapi-schema' \
make docker-runner 2>&1 | tee schema-validation.txt
exit ${PIPESTATUS[0]}
- name: Add schema validation and diff to job summary
if: always() if: always()
# show text and if for some reason, it can't be generated, state that it can't be. # show text and if for some reason, it can't be generated, state that it can't be.
run: | run: |
echo "## API Schema Change Detection Results" >> $GITHUB_STEP_SUMMARY echo "## API Schema Check Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY
# Show validation status
echo "### OpenAPI Validation" >> $GITHUB_STEP_SUMMARY
if [ -f schema-validation.txt ] && grep -q "✓ Schema is valid" schema-validation.txt; then
echo "✅ **Status:** PASSED - Schema is valid OpenAPI 3.0.3" >> $GITHUB_STEP_SUMMARY
else
echo "❌ **Status:** FAILED - Schema validation failed" >> $GITHUB_STEP_SUMMARY
if [ -f schema-validation.txt ]; then
echo "" >> $GITHUB_STEP_SUMMARY
echo "<details><summary>Validation errors</summary>" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo '```' >> $GITHUB_STEP_SUMMARY
cat schema-validation.txt >> $GITHUB_STEP_SUMMARY
echo '```' >> $GITHUB_STEP_SUMMARY
echo "</details>" >> $GITHUB_STEP_SUMMARY
fi
fi
echo "" >> $GITHUB_STEP_SUMMARY
# Show schema changes
echo "### Schema Changes" >> $GITHUB_STEP_SUMMARY
if [ -f schema-diff.txt ]; then if [ -f schema-diff.txt ]; then
if grep -q "^+" schema-diff.txt || grep -q "^-" schema-diff.txt; then if grep -q "^+" schema-diff.txt || grep -q "^-" schema-diff.txt; then
echo "### Schema changes detected" >> $GITHUB_STEP_SUMMARY echo "**Changes detected** between this PR and the base branch" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY
# Truncate to first 1000 lines to stay under GitHub's 1MB summary limit # Truncate to first 1000 lines to stay under GitHub's 1MB summary limit
TOTAL_LINES=$(wc -l < schema-diff.txt) TOTAL_LINES=$(wc -l < schema-diff.txt)
@@ -65,8 +95,8 @@ jobs:
head -n 1000 schema-diff.txt >> $GITHUB_STEP_SUMMARY head -n 1000 schema-diff.txt >> $GITHUB_STEP_SUMMARY
echo '```' >> $GITHUB_STEP_SUMMARY echo '```' >> $GITHUB_STEP_SUMMARY
else else
echo "### No schema changes detected" >> $GITHUB_STEP_SUMMARY echo "No schema changes detected" >> $GITHUB_STEP_SUMMARY
fi fi
else else
echo "### Unable to generate schema diff" >> $GITHUB_STEP_SUMMARY echo "Unable to generate schema diff" >> $GITHUB_STEP_SUMMARY
fi fi

View File

@@ -112,25 +112,27 @@ jobs:
path: reports/coverage.xml path: reports/coverage.xml
retention-days: 5 retention-days: 5
- name: Upload awx jUnit test reports - name: >-
Upload ${{
matrix.tests.coverage-upload-name || 'awx'
}} jUnit test reports to the unified dashboard
if: >- if: >-
!cancelled() !cancelled()
&& steps.make-run.outputs.test-result-files != '' && steps.make-run.outputs.test-result-files != ''
&& github.event_name == 'push' && github.event_name == 'push'
&& env.UPSTREAM_REPOSITORY_ID == github.repository_id && env.UPSTREAM_REPOSITORY_ID == github.repository_id
&& github.ref_name == github.event.repository.default_branch && github.ref_name == github.event.repository.default_branch
run: | uses: ansible/gh-action-record-test-results@cd5956ead39ec66351d0779470c8cff9638dd2b8
for junit_file in $(echo '${{ steps.make-run.outputs.test-result-files }}' | sed 's/,/ /') with:
do aggregation-server-url: ${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_URL }}
curl \ http-auth-password: >-
-v \ ${{ secrets.PDE_ORG_RESULTS_UPLOAD_PASSWORD }}
--user "${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_USER }}:${{ secrets.PDE_ORG_RESULTS_UPLOAD_PASSWORD }}" \ http-auth-username: >-
--form "xunit_xml=@${junit_file}" \ ${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_USER }}
--form "component_name=${{ matrix.tests.coverage-upload-name || 'awx' }}" \ project-component-name: >-
--form "git_commit_sha=${{ github.sha }}" \ ${{ matrix.tests.coverage-upload-name || 'awx' }}
--form "git_repository_url=https://github.com/${{ github.repository }}" \ test-result-files: >-
"${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_URL }}/api/results/upload/" ${{ steps.make-run.outputs.test-result-files }}
done
dev-env: dev-env:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -212,7 +214,7 @@ jobs:
continue-on-error: true continue-on-error: true
run: | run: |
set +e set +e
timeout 54m bash -elc ' timeout 15m bash -elc '
python -m pip install -r molecule/requirements.txt python -m pip install -r molecule/requirements.txt
python -m pip install PyYAML # for awx/tools/scripts/rewrite-awx-operator-requirements.py python -m pip install PyYAML # for awx/tools/scripts/rewrite-awx-operator-requirements.py
$(realpath ../awx/tools/scripts/rewrite-awx-operator-requirements.py) molecule/requirements.yml $(realpath ../awx) $(realpath ../awx/tools/scripts/rewrite-awx-operator-requirements.py) molecule/requirements.yml $(realpath ../awx)
@@ -294,18 +296,16 @@ jobs:
&& github.event_name == 'push' && github.event_name == 'push'
&& env.UPSTREAM_REPOSITORY_ID == github.repository_id && env.UPSTREAM_REPOSITORY_ID == github.repository_id
&& github.ref_name == github.event.repository.default_branch && github.ref_name == github.event.repository.default_branch
run: | uses: ansible/gh-action-record-test-results@cd5956ead39ec66351d0779470c8cff9638dd2b8
for junit_file in $(echo '${{ steps.make-run.outputs.test-result-files }}' | sed 's/,/ /') with:
do aggregation-server-url: ${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_URL }}
curl \ http-auth-password: >-
-v \ ${{ secrets.PDE_ORG_RESULTS_UPLOAD_PASSWORD }}
--user "${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_USER }}:${{ secrets.PDE_ORG_RESULTS_UPLOAD_PASSWORD }}" \ http-auth-username: >-
--form "xunit_xml=@${junit_file}" \ ${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_USER }}
--form "component_name=awx" \ project-component-name: awx
--form "git_commit_sha=${{ github.sha }}" \ test-result-files: >-
--form "git_repository_url=https://github.com/${{ github.repository }}" \ ${{ steps.make-run.outputs.test-result-files }}
"${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_URL }}/api/results/upload/"
done
collection-integration: collection-integration:
name: awx_collection integration name: awx_collection integration

176
.github/workflows/spec-sync-on-merge.yml vendored Normal file
View File

@@ -0,0 +1,176 @@
# Sync OpenAPI Spec on Merge
#
# This workflow runs when code is merged to the devel branch.
# It runs the dev environment to generate the OpenAPI spec, then syncs it to
# the central spec repository.
#
# FLOW: PR merged → push to branch → dev environment runs → spec synced to central repo
#
# NOTE: This is an inlined version for testing with private forks.
# Production version will use a reusable workflow from the org repos.
name: Sync OpenAPI Spec on Merge
env:
LC_ALL: "C.UTF-8"
DEV_DOCKER_OWNER: ${{ github.repository_owner }}
on:
push:
branches:
- devel
workflow_dispatch: # Allow manual triggering for testing
jobs:
sync-openapi-spec:
name: Sync OpenAPI spec to central repo
runs-on: ubuntu-latest
permissions:
packages: write
contents: read
steps:
- name: Checkout Controller repository
uses: actions/checkout@v4
with:
show-progress: false
- name: Build awx_devel image to use for schema gen
uses: ./.github/actions/awx_devel_image
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
private-github-key: ${{ secrets.PRIVATE_GITHUB_KEY }}
- name: Generate API Schema
run: |
DEV_DOCKER_TAG_BASE=ghcr.io/${OWNER_LC} \
COMPOSE_TAG=${{ github.base_ref || github.ref_name }} \
docker run -u $(id -u) --rm -v ${{ github.workspace }}:/awx_devel/:Z \
--workdir=/awx_devel `make print-DEVEL_IMAGE_NAME` /start_tests.sh genschema
- name: Verify spec file exists
run: |
SPEC_FILE="./schema.json"
if [ ! -f "$SPEC_FILE" ]; then
echo "❌ Spec file not found at $SPEC_FILE"
echo "Contents of workspace:"
ls -la .
exit 1
fi
echo "✅ Found spec file at $SPEC_FILE"
- name: Checkout spec repo
id: checkout_spec_repo
continue-on-error: true
uses: actions/checkout@v4
with:
repository: ansible-automation-platform/aap-openapi-specs
ref: ${{ github.ref_name }}
path: spec-repo
token: ${{ secrets.OPENAPI_SPEC_SYNC_TOKEN }}
- name: Fail if branch doesn't exist
if: steps.checkout_spec_repo.outcome == 'failure'
run: |
echo "##[error]❌ Branch '${{ github.ref_name }}' does not exist in the central spec repository."
echo "##[error]Expected branch: ${{ github.ref_name }}"
echo "##[error]This branch must be created in the spec repo before specs can be synced."
exit 1
- name: Compare specs
id: compare
run: |
COMPONENT_SPEC="./schema.json"
SPEC_REPO_FILE="spec-repo/controller.json"
# Check if spec file exists in spec repo
if [ ! -f "$SPEC_REPO_FILE" ]; then
echo "Spec file doesn't exist in spec repo - will create new file"
echo "has_diff=true" >> $GITHUB_OUTPUT
echo "is_new_file=true" >> $GITHUB_OUTPUT
else
# Compare files
if diff -q "$COMPONENT_SPEC" "$SPEC_REPO_FILE" > /dev/null; then
echo "✅ No differences found - specs are identical"
echo "has_diff=false" >> $GITHUB_OUTPUT
else
echo "📝 Differences found - spec has changed"
echo "has_diff=true" >> $GITHUB_OUTPUT
echo "is_new_file=false" >> $GITHUB_OUTPUT
fi
fi
- name: Update spec file
if: steps.compare.outputs.has_diff == 'true'
run: |
cp "./schema.json" "spec-repo/controller.json"
echo "✅ Updated spec-repo/controller.json"
- name: Create PR in spec repo
if: steps.compare.outputs.has_diff == 'true'
working-directory: spec-repo
env:
GH_TOKEN: ${{ secrets.OPENAPI_SPEC_SYNC_TOKEN }}
COMMIT_MESSAGE: ${{ github.event.head_commit.message }}
run: |
# Configure git
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
# Create branch for PR
SHORT_SHA="${{ github.sha }}"
SHORT_SHA="${SHORT_SHA:0:7}"
BRANCH_NAME="update-Controller-${{ github.ref_name }}-${SHORT_SHA}"
git checkout -b "$BRANCH_NAME"
# Add and commit changes
git add "controller.json"
if [ "${{ steps.compare.outputs.is_new_file }}" == "true" ]; then
COMMIT_MSG="Add Controller OpenAPI spec for ${{ github.ref_name }}"
else
COMMIT_MSG="Update Controller OpenAPI spec for ${{ github.ref_name }}"
fi
git commit -m "$COMMIT_MSG
Synced from ${{ github.repository }}@${{ github.sha }}
Source branch: ${{ github.ref_name }}
Co-Authored-By: github-actions[bot] <github-actions[bot]@users.noreply.github.com>"
# Push branch
git push origin "$BRANCH_NAME"
# Create PR
PR_TITLE="[${{ github.ref_name }}] Update Controller spec from merged commit"
PR_BODY="## Summary
Automated OpenAPI spec sync from component repository merge.
**Source:** ${{ github.repository }}@${{ github.sha }}
**Branch:** \`${{ github.ref_name }}\`
**Component:** \`Controller\`
**Spec File:** \`controller.json\`
## Changes
$(if [ "${{ steps.compare.outputs.is_new_file }}" == "true" ]; then echo "- 🆕 New spec file created"; else echo "- 📝 Spec file updated with latest changes"; fi)
## Source Commit
\`\`\`
${COMMIT_MESSAGE}
\`\`\`
---
🤖 This PR was automatically generated by the OpenAPI spec sync workflow."
gh pr create \
--title "$PR_TITLE" \
--body "$PR_BODY" \
--base "${{ github.ref_name }}" \
--head "$BRANCH_NAME"
echo "✅ Created PR in spec repo"
- name: Report results
if: always()
run: |
if [ "${{ steps.compare.outputs.has_diff }}" == "true" ]; then
echo "📝 Spec sync completed - PR created in spec repo"
else
echo "✅ Spec sync completed - no changes needed"
fi

View File

@@ -1,6 +1,6 @@
-include awx/ui/Makefile -include awx/ui/Makefile
PYTHON := $(notdir $(shell for i in python3.12 python3; do command -v $$i; done|sed 1q)) PYTHON := $(notdir $(shell for i in python3.12 python3.11 python3; do command -v $$i; done|sed 1q))
SHELL := bash SHELL := bash
DOCKER_COMPOSE ?= docker compose DOCKER_COMPOSE ?= docker compose
OFFICIAL ?= no OFFICIAL ?= no
@@ -79,7 +79,7 @@ RECEPTOR_IMAGE ?= quay.io/ansible/receptor:devel
SRC_ONLY_PKGS ?= cffi,pycparser,psycopg,twilio SRC_ONLY_PKGS ?= cffi,pycparser,psycopg,twilio
# These should be upgraded in the AWX and Ansible venv before attempting # These should be upgraded in the AWX and Ansible venv before attempting
# to install the actual requirements # to install the actual requirements
VENV_BOOTSTRAP ?= pip==25.3 setuptools==80.9.0 setuptools_scm[toml]==9.2.2 wheel==0.45.1 cython==3.1.3 VENV_BOOTSTRAP ?= pip==25.3 setuptools==80.9.0 setuptools_scm[toml]==9.2.2 wheel==0.46.3 cython==3.1.3
NAME ?= awx NAME ?= awx
@@ -289,7 +289,7 @@ dispatcher:
@if [ "$(VENV_BASE)" ]; then \ @if [ "$(VENV_BASE)" ]; then \
. $(VENV_BASE)/awx/bin/activate; \ . $(VENV_BASE)/awx/bin/activate; \
fi; \ fi; \
$(PYTHON) manage.py run_dispatcher $(PYTHON) manage.py dispatcherd
## Run to start the zeromq callback receiver ## Run to start the zeromq callback receiver
receiver: receiver:
@@ -579,6 +579,10 @@ detect-schema-change: genschema
# diff exits with 1 when files differ - capture but don't fail # diff exits with 1 when files differ - capture but don't fail
-diff -u -b reference-schema.json schema.json -diff -u -b reference-schema.json schema.json
validate-openapi-schema: genschema
@echo "Validating OpenAPI schema from schema.json..."
@python3 -c "from openapi_spec_validator import validate; import json; spec = json.load(open('schema.json')); validate(spec); print('✓ OpenAPI Schema is valid!')"
docker-compose-clean: awx/projects docker-compose-clean: awx/projects
$(DOCKER_COMPOSE) -f tools/docker-compose/_sources/docker-compose.yml rm -sf $(DOCKER_COMPOSE) -f tools/docker-compose/_sources/docker-compose.yml rm -sf

View File

@@ -89,7 +89,7 @@ class DeprecatedCredentialField(serializers.IntegerField):
def to_internal_value(self, pk): def to_internal_value(self, pk):
try: try:
pk = int(pk) pk = int(pk)
except ValueError: except (ValueError, TypeError):
self.fail('invalid') self.fail('invalid')
try: try:
Credential.objects.get(pk=pk) Credential.objects.get(pk=pk)

View File

@@ -111,7 +111,7 @@ class UnifiedJobEventPagination(Pagination):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.use_limit_paginator = False self.use_limit_paginator = False
self.limit_pagination = LimitPagination() self.limit_pagination = LimitPagination()
return super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def paginate_queryset(self, queryset, request, view=None): def paginate_queryset(self, queryset, request, view=None):
if 'limit' in request.query_params: if 'limit' in request.query_params:

View File

@@ -9,6 +9,50 @@ from drf_spectacular.views import (
) )
def filter_credential_type_schema(
result,
generator, # NOSONAR
request, # NOSONAR
public, # NOSONAR
):
"""
Postprocessing hook to filter CredentialType kind enum values.
For CredentialTypeRequest and PatchedCredentialTypeRequest schemas (POST/PUT/PATCH),
filter the 'kind' enum to only show 'cloud' and 'net' values.
This ensures the OpenAPI schema accurately reflects that only 'cloud' and 'net'
credential types can be created or modified via the API, matching the validation
in CredentialTypeSerializer.validate().
Args:
result: The OpenAPI schema dict to be modified
generator, request, public: Required by drf-spectacular interface (unused)
Returns:
The modified OpenAPI schema dict
"""
schemas = result.get('components', {}).get('schemas', {})
# Filter CredentialTypeRequest (POST/PUT) - field is required
if 'CredentialTypeRequest' in schemas:
kind_prop = schemas['CredentialTypeRequest'].get('properties', {}).get('kind', {})
if 'enum' in kind_prop:
# Filter to only cloud and net (no None - field is required)
kind_prop['enum'] = ['cloud', 'net']
kind_prop['description'] = "* `cloud` - Cloud\\n* `net` - Network"
# Filter PatchedCredentialTypeRequest (PATCH) - field is optional
if 'PatchedCredentialTypeRequest' in schemas:
kind_prop = schemas['PatchedCredentialTypeRequest'].get('properties', {}).get('kind', {})
if 'enum' in kind_prop:
# Filter to only cloud and net (None allowed - field can be omitted in PATCH)
kind_prop['enum'] = ['cloud', 'net', None]
kind_prop['description'] = "* `cloud` - Cloud\\n* `net` - Network"
return result
class CustomAutoSchema(AutoSchema): class CustomAutoSchema(AutoSchema):
"""Custom AutoSchema to add swagger_topic to tags and handle deprecated endpoints.""" """Custom AutoSchema to add swagger_topic to tags and handle deprecated endpoints."""

View File

@@ -1230,7 +1230,7 @@ class OrganizationSerializer(BaseSerializer, OpaQueryPathMixin):
# to a team. This provides a hint to the ui so it can know to not # to a team. This provides a hint to the ui so it can know to not
# display these roles for team role selection. # display these roles for team role selection.
for key in ('admin_role', 'member_role'): for key in ('admin_role', 'member_role'):
if key in summary_dict.get('object_roles', {}): if summary_dict and key in summary_dict.get('object_roles', {}):
summary_dict['object_roles'][key]['user_only'] = True summary_dict['object_roles'][key]['user_only'] = True
return summary_dict return summary_dict
@@ -2165,13 +2165,13 @@ class BulkHostDeleteSerializer(serializers.Serializer):
attrs['hosts_data'] = attrs['host_qs'].values() attrs['hosts_data'] = attrs['host_qs'].values()
if len(attrs['host_qs']) == 0: if len(attrs['host_qs']) == 0:
error_hosts = {host: "Hosts do not exist or you lack permission to delete it" for host in attrs['hosts']} error_hosts = dict.fromkeys(attrs['hosts'], "Hosts do not exist or you lack permission to delete it")
raise serializers.ValidationError({'hosts': error_hosts}) raise serializers.ValidationError({'hosts': error_hosts})
if len(attrs['host_qs']) < len(attrs['hosts']): if len(attrs['host_qs']) < len(attrs['hosts']):
hosts_exists = [host['id'] for host in attrs['hosts_data']] hosts_exists = [host['id'] for host in attrs['hosts_data']]
failed_hosts = list(set(attrs['hosts']).difference(hosts_exists)) failed_hosts = list(set(attrs['hosts']).difference(hosts_exists))
error_hosts = {host: "Hosts do not exist or you lack permission to delete it" for host in failed_hosts} error_hosts = dict.fromkeys(failed_hosts, "Hosts do not exist or you lack permission to delete it")
raise serializers.ValidationError({'hosts': error_hosts}) raise serializers.ValidationError({'hosts': error_hosts})
# Getting all inventories that the hosts can be in # Getting all inventories that the hosts can be in
@@ -3527,7 +3527,7 @@ class JobRelaunchSerializer(BaseSerializer):
choices=NEW_JOB_TYPE_CHOICES, choices=NEW_JOB_TYPE_CHOICES,
write_only=True, write_only=True,
) )
credential_passwords = VerbatimField(required=True, write_only=True) credential_passwords = VerbatimField(required=False, write_only=True)
class Meta: class Meta:
model = Job model = Job

View File

@@ -1,6 +1,6 @@
{% if content_only %}<div class="nocode ansi_fore ansi_back{% if dark %} ansi_dark{% endif %}">{% else %} {% if content_only %}<div class="nocode ansi_fore ansi_back{% if dark %} ansi_dark{% endif %}">{% else %}
<!DOCTYPE HTML> <!DOCTYPE HTML>
<html> <html lang="en">
<head> <head>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8"> <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
<title>{{ title }}</title> <title>{{ title }}</title>

View File

@@ -52,9 +52,9 @@ from ansi2html import Ansi2HTMLConverter
from datetime import timezone as dt_timezone from datetime import timezone as dt_timezone
from wsgiref.util import FileWrapper from wsgiref.util import FileWrapper
from drf_spectacular.utils import extend_schema_view, extend_schema
# django-ansible-base # django-ansible-base
from ansible_base.lib.utils.requests import get_remote_hosts
from ansible_base.rbac.models import RoleEvaluation from ansible_base.rbac.models import RoleEvaluation
from ansible_base.lib.utils.schema import extend_schema_if_available from ansible_base.lib.utils.schema import extend_schema_if_available
@@ -97,6 +97,7 @@ from awx.main.utils import (
from awx.main.utils.encryption import encrypt_value from awx.main.utils.encryption import encrypt_value
from awx.main.utils.filters import SmartFilter from awx.main.utils.filters import SmartFilter
from awx.main.utils.plugins import compute_cloud_inventory_sources from awx.main.utils.plugins import compute_cloud_inventory_sources
from awx.main.utils.proxy import get_first_remote_host_from_headers
from awx.main.utils.common import memoize from awx.main.utils.common import memoize
from awx.main.redact import UriCleaner from awx.main.redact import UriCleaner
from awx.api.permissions import ( from awx.api.permissions import (
@@ -378,6 +379,10 @@ class DashboardJobsGraphView(APIView):
class InstanceList(ListCreateAPIView): class InstanceList(ListCreateAPIView):
"""
Creates an instance if used on a Kubernetes or OpenShift deployment of Ansible Automation Platform.
"""
name = _("Instances") name = _("Instances")
model = models.Instance model = models.Instance
serializer_class = serializers.InstanceSerializer serializer_class = serializers.InstanceSerializer
@@ -1454,7 +1459,7 @@ class CredentialList(ListCreateAPIView):
@extend_schema_if_available( @extend_schema_if_available(
extensions={ extensions={
"x-ai-description": "Create a new credential. The `inputs` field contain type-specific input fields. The required fields depend on related `credential_type`. Use GET /v2/credential_types/{id}/ (tool name: controller.credential_types_retrieve) and inspect `inputs` field for the specific credential type's expected schema." "x-ai-description": "Create a new credential. The `inputs` field contain type-specific input fields. The required fields depend on related `credential_type`. Use GET /v2/credential_types/{id}/ (tool name: controller.credential_types_retrieve) and inspect `inputs` field for the specific credential type's expected schema. The fields `user` and `team` are deprecated and should not be included in the payload."
} }
) )
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
@@ -1603,7 +1608,11 @@ class CredentialExternalTest(SubDetailAPIView):
obj_permission_type = 'use' obj_permission_type = 'use'
resource_purpose = 'test external credential' resource_purpose = 'test external credential'
@extend_schema_if_available(extensions={"x-ai-description": "Test update the input values and metadata of an external credential"}) @extend_schema_if_available(extensions={"x-ai-description": """Test update the input values and metadata of an external credential.
This endpoint supports testing credentials that connect to external secret management systems
such as CyberArk AIM, CyberArk Conjur, HashiCorp Vault, AWS Secrets Manager, Azure Key Vault,
Centrify Vault, Thycotic DevOps Secrets Vault, and GitHub App Installation Access Token Lookup.
It does not support standard credential types such as Machine, SCM, and Cloud."""})
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
obj = self.get_object() obj = self.get_object()
backend_kwargs = {} backend_kwargs = {}
@@ -1617,13 +1626,16 @@ class CredentialExternalTest(SubDetailAPIView):
with set_environ(**settings.AWX_TASK_ENV): with set_environ(**settings.AWX_TASK_ENV):
obj.credential_type.plugin.backend(**backend_kwargs) obj.credential_type.plugin.backend(**backend_kwargs)
return Response({}, status=status.HTTP_202_ACCEPTED) return Response({}, status=status.HTTP_202_ACCEPTED)
except requests.exceptions.HTTPError as exc: except requests.exceptions.HTTPError:
message = 'HTTP {}'.format(exc.response.status_code) message = """Test operation is not supported for credential type {}.
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST) This endpoint only supports credentials that connect to
external secret management systems such as CyberArk, HashiCorp
Vault, or cloud-based secret managers.""".format(obj.credential_type.kind)
return Response({'detail': message}, status=status.HTTP_400_BAD_REQUEST)
except Exception as exc: except Exception as exc:
message = exc.__class__.__name__ message = exc.__class__.__name__
args = getattr(exc, 'args', []) exc_args = getattr(exc, 'args', [])
for a in args: for a in exc_args:
if isinstance(getattr(a, 'reason', None), ConnectTimeoutError): if isinstance(getattr(a, 'reason', None), ConnectTimeoutError):
message = str(a.reason) message = str(a.reason)
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST) return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
@@ -1681,8 +1693,8 @@ class CredentialTypeExternalTest(SubDetailAPIView):
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST) return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
except Exception as exc: except Exception as exc:
message = exc.__class__.__name__ message = exc.__class__.__name__
args = getattr(exc, 'args', []) args_exc = getattr(exc, 'args', [])
for a in args: for a in args_exc:
if isinstance(getattr(a, 'reason', None), ConnectTimeoutError): if isinstance(getattr(a, 'reason', None), ConnectTimeoutError):
message = str(a.reason) message = str(a.reason)
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST) return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
@@ -2469,6 +2481,11 @@ class JobTemplateDetail(RelatedJobsPreventDeleteMixin, RetrieveUpdateDestroyAPIV
resource_purpose = 'job template detail' resource_purpose = 'job template detail'
@extend_schema_view(
retrieve=extend_schema(
extensions={'x-ai-description': 'List job template launch criteria'},
)
)
class JobTemplateLaunch(RetrieveAPIView): class JobTemplateLaunch(RetrieveAPIView):
model = models.JobTemplate model = models.JobTemplate
obj_permission_type = 'start' obj_permission_type = 'start'
@@ -2477,6 +2494,9 @@ class JobTemplateLaunch(RetrieveAPIView):
resource_purpose = 'launch a job from a job template' resource_purpose = 'launch a job from a job template'
def update_raw_data(self, data): def update_raw_data(self, data):
"""
Use the ID of a job template to retrieve its launch details.
"""
try: try:
obj = self.get_object() obj = self.get_object()
except PermissionDenied: except PermissionDenied:
@@ -2857,7 +2877,8 @@ class JobTemplateCallback(GenericAPIView):
host for the current request. host for the current request.
""" """
# Find the list of remote host names/IPs to check. # Find the list of remote host names/IPs to check.
remote_hosts = set(get_remote_hosts(self.request)) # Only consider the first entry from each header (for comma-separated values like X-Forwarded-For)
remote_hosts = get_first_remote_host_from_headers(self.request, settings.REMOTE_HOST_HEADERS)
# Add the reverse lookup of IP addresses. # Add the reverse lookup of IP addresses.
for rh in list(remote_hosts): for rh in list(remote_hosts):
try: try:
@@ -3310,6 +3331,11 @@ class WorkflowJobTemplateLabelList(JobTemplateLabelList):
resource_purpose = 'labels of a workflow job template' resource_purpose = 'labels of a workflow job template'
@extend_schema_view(
retrieve=extend_schema(
extensions={'x-ai-description': 'List workflow job template launch criteria.'},
)
)
class WorkflowJobTemplateLaunch(RetrieveAPIView): class WorkflowJobTemplateLaunch(RetrieveAPIView):
model = models.WorkflowJobTemplate model = models.WorkflowJobTemplate
obj_permission_type = 'start' obj_permission_type = 'start'
@@ -3318,6 +3344,9 @@ class WorkflowJobTemplateLaunch(RetrieveAPIView):
resource_purpose = 'launch a workflow job from a workflow job template' resource_purpose = 'launch a workflow job from a workflow job template'
def update_raw_data(self, data): def update_raw_data(self, data):
"""
Use the ID of a workflow job template to retrieve its launch details.
"""
try: try:
obj = self.get_object() obj = self.get_object()
except PermissionDenied: except PermissionDenied:
@@ -3710,6 +3739,11 @@ class JobCancel(GenericCancelView):
return super().post(request, *args, **kwargs) return super().post(request, *args, **kwargs)
@extend_schema_view(
retrieve=extend_schema(
extensions={'x-ai-description': 'List job relaunch criteria'},
)
)
class JobRelaunch(RetrieveAPIView): class JobRelaunch(RetrieveAPIView):
model = models.Job model = models.Job
obj_permission_type = 'start' obj_permission_type = 'start'
@@ -3717,6 +3751,7 @@ class JobRelaunch(RetrieveAPIView):
resource_purpose = 'relaunch a job' resource_purpose = 'relaunch a job'
def update_raw_data(self, data): def update_raw_data(self, data):
"""Use the ID of a job to retrieve data on retry attempts and necessary passwords."""
data = super(JobRelaunch, self).update_raw_data(data) data = super(JobRelaunch, self).update_raw_data(data)
try: try:
obj = self.get_object() obj = self.get_object()

View File

@@ -25,7 +25,6 @@ import requests
from ansible_base.lib.utils.schema import extend_schema_if_available from ansible_base.lib.utils.schema import extend_schema_if_available
from awx import MODE
from awx.api.generics import APIView from awx.api.generics import APIView
from awx.conf.registry import settings_registry from awx.conf.registry import settings_registry
from awx.main.analytics import all_collectors from awx.main.analytics import all_collectors
@@ -33,7 +32,7 @@ from awx.main.ha import is_ha_environment
from awx.main.tasks.system import clear_setting_cache from awx.main.tasks.system import clear_setting_cache
from awx.main.utils import get_awx_version, get_custom_venv_choices from awx.main.utils import get_awx_version, get_custom_venv_choices
from awx.main.utils.licensing import validate_entitlement_manifest from awx.main.utils.licensing import validate_entitlement_manifest
from awx.api.versioning import URLPathVersioning, reverse, drf_reverse from awx.api.versioning import URLPathVersioning, reverse
from awx.main.constants import PRIVILEGE_ESCALATION_METHODS from awx.main.constants import PRIVILEGE_ESCALATION_METHODS
from awx.main.models import Project, Organization, Instance, InstanceGroup, JobTemplate from awx.main.models import Project, Organization, Instance, InstanceGroup, JobTemplate
from awx.main.utils import set_environ from awx.main.utils import set_environ
@@ -62,8 +61,6 @@ class ApiRootView(APIView):
data['custom_logo'] = settings.CUSTOM_LOGO data['custom_logo'] = settings.CUSTOM_LOGO
data['custom_login_info'] = settings.CUSTOM_LOGIN_INFO data['custom_login_info'] = settings.CUSTOM_LOGIN_INFO
data['login_redirect_override'] = settings.LOGIN_REDIRECT_OVERRIDE data['login_redirect_override'] = settings.LOGIN_REDIRECT_OVERRIDE
if MODE == 'development':
data['docs'] = drf_reverse('api:schema-swagger-ui')
return Response(data) return Response(data)

View File

@@ -133,7 +133,7 @@ class WebhookReceiverBase(APIView):
@csrf_exempt @csrf_exempt
@extend_schema_if_available(extensions={"x-ai-description": "Receive a webhook event and trigger a job"}) @extend_schema_if_available(extensions={"x-ai-description": "Receive a webhook event and trigger a job"})
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs_in):
# Ensure that the full contents of the request are captured for multiple uses. # Ensure that the full contents of the request are captured for multiple uses.
request.body request.body

View File

@@ -1,15 +1,17 @@
# Python # Python
import logging import logging
# Dispatcherd
from dispatcherd.publish import task
# AWX # AWX
from awx.main.analytics.subsystem_metrics import DispatcherMetrics, CallbackReceiverMetrics from awx.main.analytics.subsystem_metrics import DispatcherMetrics, CallbackReceiverMetrics
from awx.main.dispatch.publish import task as task_awx
from awx.main.dispatch import get_task_queuename from awx.main.dispatch import get_task_queuename
logger = logging.getLogger('awx.main.scheduler') logger = logging.getLogger('awx.main.scheduler')
@task_awx(queue=get_task_queuename, timeout=300, on_duplicate='discard') @task(queue=get_task_queuename, timeout=300, on_duplicate='discard')
def send_subsystem_metrics(): def send_subsystem_metrics():
DispatcherMetrics().send_metrics() DispatcherMetrics().send_metrics()
CallbackReceiverMetrics().send_metrics() CallbackReceiverMetrics().send_metrics()

View File

@@ -0,0 +1,41 @@
import http.client
import socket
import urllib.error
import urllib.request
import logging
from django.conf import settings
logger = logging.getLogger(__name__)
def get_dispatcherd_metrics(request):
metrics_cfg = settings.METRICS_SUBSYSTEM_CONFIG.get('server', {}).get(settings.METRICS_SERVICE_DISPATCHER, {})
host = metrics_cfg.get('host', 'localhost')
port = metrics_cfg.get('port', 8015)
metrics_filter = []
if request is not None and hasattr(request, "query_params"):
try:
nodes_filter = request.query_params.getlist("node")
except Exception:
nodes_filter = []
if nodes_filter and settings.CLUSTER_HOST_ID not in nodes_filter:
return ''
try:
metrics_filter = request.query_params.getlist("metric")
except Exception:
metrics_filter = []
if metrics_filter:
# Right now we have no way of filtering the dispatcherd metrics
# so just avoid getting in the way if another metric is filtered for
return ''
url = f"http://{host}:{port}/metrics"
try:
with urllib.request.urlopen(url, timeout=1.0) as response:
payload = response.read()
if not payload:
return ''
return payload.decode('utf-8')
except (urllib.error.URLError, UnicodeError, socket.timeout, TimeoutError, http.client.HTTPException) as exc:
logger.debug(f"Failed to collect dispatcherd metrics from {url}: {exc}")
return ''

View File

@@ -15,6 +15,7 @@ from rest_framework.request import Request
from awx.main.consumers import emit_channel_notification from awx.main.consumers import emit_channel_notification
from awx.main.utils import is_testing from awx.main.utils import is_testing
from awx.main.utils.redis import get_redis_client from awx.main.utils.redis import get_redis_client
from .dispatcherd_metrics import get_dispatcherd_metrics
root_key = settings.SUBSYSTEM_METRICS_REDIS_KEY_PREFIX root_key = settings.SUBSYSTEM_METRICS_REDIS_KEY_PREFIX
logger = logging.getLogger('awx.main.analytics') logger = logging.getLogger('awx.main.analytics')
@@ -398,11 +399,6 @@ class DispatcherMetrics(Metrics):
SetFloatM('workflow_manager_recorded_timestamp', 'Unix timestamp when metrics were last recorded'), SetFloatM('workflow_manager_recorded_timestamp', 'Unix timestamp when metrics were last recorded'),
SetFloatM('workflow_manager_spawn_workflow_graph_jobs_seconds', 'Time spent spawning workflow tasks'), SetFloatM('workflow_manager_spawn_workflow_graph_jobs_seconds', 'Time spent spawning workflow tasks'),
SetFloatM('workflow_manager_get_tasks_seconds', 'Time spent loading workflow tasks from db'), SetFloatM('workflow_manager_get_tasks_seconds', 'Time spent loading workflow tasks from db'),
# dispatcher subsystem metrics
SetIntM('dispatcher_pool_scale_up_events', 'Number of times local dispatcher scaled up a worker since startup'),
SetIntM('dispatcher_pool_active_task_count', 'Number of active tasks in the worker pool when last task was submitted'),
SetIntM('dispatcher_pool_max_worker_count', 'Highest number of workers in worker pool in last collection interval, about 20s'),
SetFloatM('dispatcher_availability', 'Fraction of time (in last collection interval) dispatcher was able to receive messages'),
] ]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@@ -430,8 +426,12 @@ class CallbackReceiverMetrics(Metrics):
def metrics(request): def metrics(request):
output_text = '' output_text = ''
for m in [DispatcherMetrics(), CallbackReceiverMetrics()]: output_text += DispatcherMetrics().generate_metrics(request)
output_text += m.generate_metrics(request) output_text += CallbackReceiverMetrics().generate_metrics(request)
dispatcherd_metrics = get_dispatcherd_metrics(request)
if dispatcherd_metrics:
output_text += dispatcherd_metrics
return output_text return output_text
@@ -481,13 +481,6 @@ class CallbackReceiverMetricsServer(MetricsServer):
super().__init__(settings.METRICS_SERVICE_CALLBACK_RECEIVER, registry) super().__init__(settings.METRICS_SERVICE_CALLBACK_RECEIVER, registry)
class DispatcherMetricsServer(MetricsServer):
def __init__(self):
registry = CollectorRegistry(auto_describe=True)
registry.register(CustomToPrometheusMetricsCollector(DispatcherMetrics(metrics_have_changed=False)))
super().__init__(settings.METRICS_SERVICE_DISPATCHER, registry)
class WebsocketsMetricsServer(MetricsServer): class WebsocketsMetricsServer(MetricsServer):
def __init__(self): def __init__(self):
registry = CollectorRegistry(auto_describe=True) registry = CollectorRegistry(auto_describe=True)

View File

@@ -82,7 +82,7 @@ class MainConfig(AppConfig):
def configure_dispatcherd(self): def configure_dispatcherd(self):
"""This implements the default configuration for dispatcherd """This implements the default configuration for dispatcherd
If running the tasking service like awx-manage run_dispatcher, If running the tasking service like awx-manage dispatcherd,
some additional config will be applied on top of this. some additional config will be applied on top of this.
This configuration provides the minimum such that code can submit This configuration provides the minimum such that code can submit
tasks to pg_notify to run those tasks. tasks to pg_notify to run those tasks.

View File

@@ -77,14 +77,13 @@ class PubSub(object):
n = psycopg.connection.Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) n = psycopg.connection.Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
yield n yield n
def events(self, yield_timeouts=False): def events(self):
if not self.conn.autocommit: if not self.conn.autocommit:
raise RuntimeError('Listening for events can only be done in autocommit mode') raise RuntimeError('Listening for events can only be done in autocommit mode')
while True: while True:
if select.select([self.conn], [], [], self.select_timeout) == NOT_READY: if select.select([self.conn], [], [], self.select_timeout) == NOT_READY:
if yield_timeouts: yield None
yield None
else: else:
notification_generator = self.current_notifies(self.conn) notification_generator = self.current_notifies(self.conn)
for notification in notification_generator: for notification in notification_generator:

View File

@@ -2,7 +2,7 @@ from django.conf import settings
from ansible_base.lib.utils.db import get_pg_notify_params from ansible_base.lib.utils.db import get_pg_notify_params
from awx.main.dispatch import get_task_queuename from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.pool import get_auto_max_workers from awx.main.utils.common import get_auto_max_workers
def get_dispatcherd_config(for_service: bool = False, mock_publish: bool = False) -> dict: def get_dispatcherd_config(for_service: bool = False, mock_publish: bool = False) -> dict:
@@ -30,7 +30,7 @@ def get_dispatcherd_config(for_service: bool = False, mock_publish: bool = False
}, },
"main_kwargs": {"node_id": settings.CLUSTER_HOST_ID}, "main_kwargs": {"node_id": settings.CLUSTER_HOST_ID},
"process_manager_cls": "ForkServerManager", "process_manager_cls": "ForkServerManager",
"process_manager_kwargs": {"preload_modules": ['awx.main.dispatch.hazmat']}, "process_manager_kwargs": {"preload_modules": ['awx.main.dispatch.prefork']},
}, },
"brokers": {}, "brokers": {},
"publish": {}, "publish": {},
@@ -38,8 +38,8 @@ def get_dispatcherd_config(for_service: bool = False, mock_publish: bool = False
} }
if mock_publish: if mock_publish:
config["brokers"]["noop"] = {} config["brokers"]["dispatcherd.testing.brokers.noop"] = {}
config["publish"]["default_broker"] = "noop" config["publish"]["default_broker"] = "dispatcherd.testing.brokers.noop"
else: else:
config["brokers"]["pg_notify"] = { config["brokers"]["pg_notify"] = {
"config": get_pg_notify_params(), "config": get_pg_notify_params(),
@@ -56,5 +56,11 @@ def get_dispatcherd_config(for_service: bool = False, mock_publish: bool = False
} }
config["brokers"]["pg_notify"]["channels"] = ['tower_broadcast_all', 'tower_settings_change', get_task_queuename()] config["brokers"]["pg_notify"]["channels"] = ['tower_broadcast_all', 'tower_settings_change', get_task_queuename()]
metrics_cfg = settings.METRICS_SUBSYSTEM_CONFIG.get('server', {}).get(settings.METRICS_SERVICE_DISPATCHER)
if metrics_cfg:
config["service"]["metrics_kwargs"] = {
"host": metrics_cfg.get("host", "localhost"),
"port": metrics_cfg.get("port", 8015),
}
return config return config

View File

@@ -1,77 +0,0 @@
import logging
import uuid
import json
from django.db import connection
from awx.main.dispatch import get_task_queuename
from awx.main.utils.redis import get_redis_client
from . import pg_bus_conn
logger = logging.getLogger('awx.main.dispatch')
class Control(object):
services = ('dispatcher', 'callback_receiver')
result = None
def __init__(self, service, host=None):
if service not in self.services:
raise RuntimeError('{} must be in {}'.format(service, self.services))
self.service = service
self.queuename = host or get_task_queuename()
def status(self, *args, **kwargs):
r = get_redis_client()
if self.service == 'dispatcher':
stats = r.get(f'awx_{self.service}_statistics') or b''
return stats.decode('utf-8')
else:
workers = []
for key in r.keys('awx_callback_receiver_statistics_*'):
workers.append(r.get(key).decode('utf-8'))
return '\n'.join(workers)
def running(self, *args, **kwargs):
return self.control_with_reply('running', *args, **kwargs)
def cancel(self, task_ids, with_reply=True):
if with_reply:
return self.control_with_reply('cancel', extra_data={'task_ids': task_ids})
else:
self.control({'control': 'cancel', 'task_ids': task_ids, 'reply_to': None}, extra_data={'task_ids': task_ids})
def schedule(self, *args, **kwargs):
return self.control_with_reply('schedule', *args, **kwargs)
@classmethod
def generate_reply_queue_name(cls):
return f"reply_to_{str(uuid.uuid4()).replace('-','_')}"
def control_with_reply(self, command, timeout=5, extra_data=None):
logger.warning('checking {} {} for {}'.format(self.service, command, self.queuename))
reply_queue = Control.generate_reply_queue_name()
self.result = None
if not connection.get_autocommit():
raise RuntimeError('Control-with-reply messages can only be done in autocommit mode')
with pg_bus_conn(select_timeout=timeout) as conn:
conn.listen(reply_queue)
send_data = {'control': command, 'reply_to': reply_queue}
if extra_data:
send_data.update(extra_data)
conn.notify(self.queuename, json.dumps(send_data))
for reply in conn.events(yield_timeouts=True):
if reply is None:
logger.error(f'{self.service} did not reply within {timeout}s')
raise RuntimeError(f"{self.service} did not reply within {timeout}s")
break
return json.loads(reply.payload)
def control(self, msg, **kwargs):
with pg_bus_conn() as conn:
conn.notify(self.queuename, json.dumps(msg))

View File

@@ -1,146 +0,0 @@
import logging
import time
import yaml
from datetime import datetime
logger = logging.getLogger('awx.main.dispatch.periodic')
class ScheduledTask:
"""
Class representing schedules, very loosely modeled after python schedule library Job
the idea of this class is to:
- only deal in relative times (time since the scheduler global start)
- only deal in integer math for target runtimes, but float for current relative time
Missed schedule policy:
Invariant target times are maintained, meaning that if interval=10s offset=0
and it runs at t=7s, then it calls for next run in 3s.
However, if a complete interval has passed, that is counted as a missed run,
and missed runs are abandoned (no catch-up runs).
"""
def __init__(self, name: str, data: dict):
# parameters need for schedule computation
self.interval = int(data['schedule'].total_seconds())
self.offset = 0 # offset relative to start time this schedule begins
self.index = 0 # number of periods of the schedule that has passed
# parameters that do not affect scheduling logic
self.last_run = None # time of last run, only used for debug
self.completed_runs = 0 # number of times schedule is known to run
self.name = name
self.data = data # used by caller to know what to run
@property
def next_run(self):
"Time until the next run with t=0 being the global_start of the scheduler class"
return (self.index + 1) * self.interval + self.offset
def due_to_run(self, relative_time):
return bool(self.next_run <= relative_time)
def expected_runs(self, relative_time):
return int((relative_time - self.offset) / self.interval)
def mark_run(self, relative_time):
self.last_run = relative_time
self.completed_runs += 1
new_index = self.expected_runs(relative_time)
if new_index > self.index + 1:
logger.warning(f'Missed {new_index - self.index - 1} schedules of {self.name}')
self.index = new_index
def missed_runs(self, relative_time):
"Number of times job was supposed to ran but failed to, only used for debug"
missed_ct = self.expected_runs(relative_time) - self.completed_runs
# if this is currently due to run do not count that as a missed run
if missed_ct and self.due_to_run(relative_time):
missed_ct -= 1
return missed_ct
class Scheduler:
def __init__(self, schedule):
"""
Expects schedule in the form of a dictionary like
{
'job1': {'schedule': timedelta(seconds=50), 'other': 'stuff'}
}
Only the schedule nearest-second value is used for scheduling,
the rest of the data is for use by the caller to know what to run.
"""
self.jobs = [ScheduledTask(name, data) for name, data in schedule.items()]
min_interval = min(job.interval for job in self.jobs)
num_jobs = len(self.jobs)
# this is intentionally oppioniated against spammy schedules
# a core goal is to spread out the scheduled tasks (for worker management)
# and high-frequency schedules just do not work with that
if num_jobs > min_interval:
raise RuntimeError(f'Number of schedules ({num_jobs}) is more than the shortest schedule interval ({min_interval} seconds).')
# even space out jobs over the base interval
for i, job in enumerate(self.jobs):
job.offset = (i * min_interval) // num_jobs
# internally times are all referenced relative to startup time, add grace period
self.global_start = time.time() + 2.0
def get_and_mark_pending(self, reftime=None):
if reftime is None:
reftime = time.time() # mostly for tests
relative_time = reftime - self.global_start
to_run = []
for job in self.jobs:
if job.due_to_run(relative_time):
to_run.append(job)
logger.debug(f'scheduler found {job.name} to run, {relative_time - job.next_run} seconds after target')
job.mark_run(relative_time)
return to_run
def time_until_next_run(self, reftime=None):
if reftime is None:
reftime = time.time() # mostly for tests
relative_time = reftime - self.global_start
next_job = min(self.jobs, key=lambda j: j.next_run)
delta = next_job.next_run - relative_time
if delta <= 0.1:
# careful not to give 0 or negative values to the select timeout, which has unclear interpretation
logger.warning(f'Scheduler next run of {next_job.name} is {-delta} seconds in the past')
return 0.1
elif delta > 20.0:
logger.warning(f'Scheduler next run unexpectedly over 20 seconds in future: {delta}')
return 20.0
logger.debug(f'Scheduler next run is {next_job.name} in {delta} seconds')
return delta
def debug(self, *args, **kwargs):
data = dict()
data['title'] = 'Scheduler status'
reftime = time.time()
now = datetime.fromtimestamp(reftime).strftime('%Y-%m-%d %H:%M:%S UTC')
start_time = datetime.fromtimestamp(self.global_start).strftime('%Y-%m-%d %H:%M:%S UTC')
relative_time = reftime - self.global_start
data['started_time'] = start_time
data['current_time'] = now
data['current_time_relative'] = round(relative_time, 3)
data['total_schedules'] = len(self.jobs)
data['schedule_list'] = dict(
[
(
job.name,
dict(
last_run_seconds_ago=round(relative_time - job.last_run, 3) if job.last_run else None,
next_run_in_seconds=round(job.next_run - relative_time, 3),
offset_in_seconds=job.offset,
completed_runs=job.completed_runs,
missed_runs=job.missed_runs(relative_time),
),
)
for job in sorted(self.jobs, key=lambda job: job.interval)
]
)
return yaml.safe_dump(data, default_flow_style=False, sort_keys=False)

View File

@@ -1,583 +1,54 @@
import logging import logging
import os
import random
import signal
import sys
import time
import traceback
from datetime import datetime, timezone
from uuid import uuid4
import json
import collections
from multiprocessing import Process from multiprocessing import Process
from multiprocessing import Queue as MPQueue
from queue import Full as QueueFull, Empty as QueueEmpty
from django.conf import settings from django.conf import settings
from django.db import connection as django_connection, connections from django.db import connection as django_connection
from django.core.cache import cache as django_cache from django.core.cache import cache as django_cache
from django.utils.timezone import now as tz_now
from django_guid import set_guid
from jinja2 import Template
import psutil
from ansible_base.lib.logging.runtime import log_excess_runtime logger = logging.getLogger('awx.main.commands.run_callback_receiver')
from awx.main.models import UnifiedJob
from awx.main.dispatch import reaper
from awx.main.utils.common import get_mem_effective_capacity, get_corrected_memory, get_corrected_cpu, get_cpu_effective_capacity
# ansible-runner
from ansible_runner.utils.capacity import get_mem_in_bytes, get_cpu_count
if 'run_callback_receiver' in sys.argv:
logger = logging.getLogger('awx.main.commands.run_callback_receiver')
else:
logger = logging.getLogger('awx.main.dispatch')
RETIRED_SENTINEL_TASK = "[retired]"
class NoOpResultQueue(object):
def put(self, item):
pass
class PoolWorker(object): class PoolWorker(object):
""" """
Used to track a worker child process and its pending and finished messages. A simple wrapper around a multiprocessing.Process that tracks a worker child process.
This class makes use of two distinct multiprocessing.Queues to track state: The worker process runs the provided target function.
- self.queue: this is a queue which represents pending messages that should
be handled by this worker process; as new AMQP messages come
in, a pool will put() them into this queue; the child
process that is forked will get() from this queue and handle
received messages in an endless loop
- self.finished: this is a queue which the worker process uses to signal
that it has finished processing a message
When a message is put() onto this worker, it is tracked in
self.managed_tasks.
Periodically, the worker will call .calculate_managed_tasks(), which will
cause messages in self.finished to be removed from self.managed_tasks.
In this way, self.managed_tasks represents a view of the messages assigned
to a specific process. The message at [0] is the least-recently inserted
message, and it represents what the worker is running _right now_
(self.current_task).
A worker is "busy" when it has at least one message in self.managed_tasks.
It is "idle" when self.managed_tasks is empty.
""" """
track_managed_tasks = False def __init__(self, target, args):
self.process = Process(target=target, args=args)
def __init__(self, queue_size, target, args, **kwargs):
self.messages_sent = 0
self.messages_finished = 0
self.managed_tasks = collections.OrderedDict()
self.finished = MPQueue(queue_size) if self.track_managed_tasks else NoOpResultQueue()
self.queue = MPQueue(queue_size)
self.process = Process(target=target, args=(self.queue, self.finished) + args)
self.process.daemon = True self.process.daemon = True
self.creation_time = time.monotonic()
self.retiring = False
def start(self): def start(self):
self.process.start() self.process.start()
def put(self, body):
if self.retiring:
uuid = body.get('uuid', 'N/A') if isinstance(body, dict) else 'N/A'
logger.info(f"Worker pid:{self.pid} is retiring. Refusing new task {uuid}.")
raise QueueFull("Worker is retiring and not accepting new tasks") # AutoscalePool.write handles QueueFull
uuid = '?'
if isinstance(body, dict):
if not body.get('uuid'):
body['uuid'] = str(uuid4())
uuid = body['uuid']
if self.track_managed_tasks:
self.managed_tasks[uuid] = body
self.queue.put(body, block=True, timeout=5)
self.messages_sent += 1
self.calculate_managed_tasks()
def quit(self):
"""
Send a special control message to the worker that tells it to exit
gracefully.
"""
self.queue.put('QUIT')
@property
def age(self):
"""Returns the current age of the worker in seconds."""
return time.monotonic() - self.creation_time
@property
def pid(self):
return self.process.pid
@property
def qsize(self):
return self.queue.qsize()
@property
def alive(self):
return self.process.is_alive()
@property
def mb(self):
if self.alive:
return '{:0.3f}'.format(psutil.Process(self.pid).memory_info().rss / 1024.0 / 1024.0)
return '0'
@property
def exitcode(self):
return str(self.process.exitcode)
def calculate_managed_tasks(self):
if not self.track_managed_tasks:
return
# look to see if any tasks were finished
finished = []
for _ in range(self.finished.qsize()):
try:
finished.append(self.finished.get(block=False))
except QueueEmpty:
break # qsize is not always _totally_ up to date
# if any tasks were finished, removed them from the managed tasks for
# this worker
for uuid in finished:
try:
del self.managed_tasks[uuid]
self.messages_finished += 1
except KeyError:
# ansible _sometimes_ appears to send events w/ duplicate UUIDs;
# UUIDs for ansible events are *not* actually globally unique
# when this occurs, it's _fine_ to ignore this KeyError because
# the purpose of self.managed_tasks is to just track internal
# state of which events are *currently* being processed.
logger.warning('Event UUID {} appears to be have been duplicated.'.format(uuid))
if self.retiring:
self.managed_tasks[RETIRED_SENTINEL_TASK] = {'task': RETIRED_SENTINEL_TASK}
@property
def current_task(self):
if not self.track_managed_tasks:
return None
self.calculate_managed_tasks()
# the task at [0] is the one that's running right now (or is about to
# be running)
if len(self.managed_tasks):
return self.managed_tasks[list(self.managed_tasks.keys())[0]]
return None
@property
def orphaned_tasks(self):
if not self.track_managed_tasks:
return []
orphaned = []
if not self.alive:
# if this process had a running task that never finished,
# requeue its error callbacks
current_task = self.current_task
if isinstance(current_task, dict):
orphaned.extend(current_task.get('errbacks', []))
# if this process has any pending messages requeue them
for _ in range(self.qsize):
try:
message = self.queue.get(block=False)
if message != 'QUIT':
orphaned.append(message)
except QueueEmpty:
break # qsize is not always _totally_ up to date
if len(orphaned):
logger.error('requeuing {} messages from gone worker pid:{}'.format(len(orphaned), self.pid))
return orphaned
@property
def busy(self):
self.calculate_managed_tasks()
return len(self.managed_tasks) > 0
@property
def idle(self):
return not self.busy
class StatefulPoolWorker(PoolWorker):
track_managed_tasks = True
class WorkerPool(object): class WorkerPool(object):
""" """
Creates a pool of forked PoolWorkers. Creates a pool of forked PoolWorkers.
As WorkerPool.write(...) is called (generally, by a kombu consumer Each worker process runs the provided target function in an isolated process.
implementation when it receives an AMQP message), messages are passed to The pool manages spawning, tracking, and stopping worker processes.
one of the multiprocessing Queues where some work can be done on them.
class MessagePrinter(awx.main.dispatch.worker.BaseWorker): Example:
pool = WorkerPool(workers_num=4) # spawn four worker processes
def perform_work(self, body):
print(body)
pool = WorkerPool(min_workers=4) # spawn four worker processes
pool.init_workers(MessagePrint().work_loop)
pool.write(
0, # preferred worker 0
'Hello, World!'
)
""" """
pool_cls = PoolWorker def __init__(self, workers_num=None):
debug_meta = '' self.workers_num = workers_num or settings.JOB_EVENT_WORKERS
def __init__(self, min_workers=None, queue_size=None): def init_workers(self, target):
self.name = settings.CLUSTER_HOST_ID for idx in range(self.workers_num):
self.pid = os.getpid() # It's important to close these because we're _about_ to fork, and we
self.min_workers = min_workers or settings.JOB_EVENT_WORKERS # don't want the forked processes to inherit the open sockets
self.queue_size = queue_size or settings.JOB_EVENT_MAX_QUEUE_SIZE # for the DB and cache connections (that way lies race conditions)
self.workers = [] django_connection.close()
django_cache.close()
def __len__(self): worker = PoolWorker(target, (idx,))
return len(self.workers)
def init_workers(self, target, *target_args):
self.target = target
self.target_args = target_args
for idx in range(self.min_workers):
self.up()
def up(self):
idx = len(self.workers)
# It's important to close these because we're _about_ to fork, and we
# don't want the forked processes to inherit the open sockets
# for the DB and cache connections (that way lies race conditions)
django_connection.close()
django_cache.close()
worker = self.pool_cls(self.queue_size, self.target, (idx,) + self.target_args)
self.workers.append(worker)
try:
worker.start()
except Exception:
logger.exception('could not fork')
else:
logger.debug('scaling up worker pid:{}'.format(worker.pid))
return idx, worker
def debug(self, *args, **kwargs):
tmpl = Template(
'Recorded at: {{ dt }} \n'
'{{ pool.name }}[pid:{{ pool.pid }}] workers total={{ workers|length }} {{ meta }} \n'
'{% for w in workers %}'
'. worker[pid:{{ w.pid }}]{% if not w.alive %} GONE exit={{ w.exitcode }}{% endif %}'
' sent={{ w.messages_sent }}'
' age={{ "%.0f"|format(w.age) }}s'
' retiring={{ w.retiring }}'
'{% if w.messages_finished %} finished={{ w.messages_finished }}{% endif %}'
' qsize={{ w.managed_tasks|length }}'
' rss={{ w.mb }}MB'
'{% for task in w.managed_tasks.values() %}'
'\n - {% if loop.index0 == 0 %}running {% if "age" in task %}for: {{ "%.1f" % task["age"] }}s {% endif %}{% else %}queued {% endif %}'
'{{ task["uuid"] }} '
'{% if "task" in task %}'
'{{ task["task"].rsplit(".", 1)[-1] }}'
# don't print kwargs, they often contain launch-time secrets
'(*{{ task.get("args", []) }})'
'{% endif %}'
'{% endfor %}'
'{% if not w.managed_tasks|length %}'
' [IDLE]'
'{% endif %}'
'\n'
'{% endfor %}'
)
now = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')
return tmpl.render(pool=self, workers=self.workers, meta=self.debug_meta, dt=now)
def write(self, preferred_queue, body):
queue_order = sorted(range(len(self.workers)), key=lambda x: -1 if x == preferred_queue else x)
write_attempt_order = []
for queue_actual in queue_order:
try: try:
self.workers[queue_actual].put(body) worker.start()
return queue_actual
except QueueFull:
pass
except Exception: except Exception:
tb = traceback.format_exc() logger.exception('could not fork')
logger.warning("could not write to queue %s" % preferred_queue)
logger.warning("detail: {}".format(tb))
write_attempt_order.append(preferred_queue)
logger.error("could not write payload to any queue, attempted order: {}".format(write_attempt_order))
return None
def stop(self, signum):
try:
for worker in self.workers:
os.kill(worker.pid, signum)
except Exception:
logger.exception('could not kill {}'.format(worker.pid))
def get_auto_max_workers():
"""Method we normally rely on to get max_workers
Uses almost same logic as Instance.local_health_check
The important thing is to be MORE than Instance.capacity
so that the task-manager does not over-schedule this node
Ideally we would just use the capacity from the database plus reserve workers,
but this poses some bootstrap problems where OCP task containers
register themselves after startup
"""
# Get memory from ansible-runner
total_memory_gb = get_mem_in_bytes()
# This may replace memory calculation with a user override
corrected_memory = get_corrected_memory(total_memory_gb)
# Get same number as max forks based on memory, this function takes memory as bytes
mem_capacity = get_mem_effective_capacity(corrected_memory, is_control_node=True)
# Follow same process for CPU capacity constraint
cpu_count = get_cpu_count()
corrected_cpu = get_corrected_cpu(cpu_count)
cpu_capacity = get_cpu_effective_capacity(corrected_cpu, is_control_node=True)
# Here is what is different from health checks,
auto_max = max(mem_capacity, cpu_capacity)
# add magic number of extra workers to ensure
# we have a few extra workers to run the heartbeat
auto_max += 7
return auto_max
class AutoscalePool(WorkerPool):
"""
An extended pool implementation that automatically scales workers up and
down based on demand
"""
pool_cls = StatefulPoolWorker
def __init__(self, *args, **kwargs):
self.max_workers = kwargs.pop('max_workers', None)
self.max_worker_lifetime_seconds = kwargs.pop(
'max_worker_lifetime_seconds', getattr(settings, 'WORKER_MAX_LIFETIME_SECONDS', 14400)
) # Default to 4 hours
super(AutoscalePool, self).__init__(*args, **kwargs)
if self.max_workers is None:
self.max_workers = get_auto_max_workers()
# max workers can't be less than min_workers
self.max_workers = max(self.min_workers, self.max_workers)
# the task manager enforces settings.TASK_MANAGER_TIMEOUT on its own
# but if the task takes longer than the time defined here, we will force it to stop here
self.task_manager_timeout = settings.TASK_MANAGER_TIMEOUT + settings.TASK_MANAGER_TIMEOUT_GRACE_PERIOD
# initialize some things for subsystem metrics periodic gathering
# the AutoscalePool class does not save these to redis directly, but reports via produce_subsystem_metrics
self.scale_up_ct = 0
self.worker_count_max = 0
# last time we wrote current tasks, to avoid too much log spam
self.last_task_list_log = time.monotonic()
def produce_subsystem_metrics(self, metrics_object):
metrics_object.set('dispatcher_pool_scale_up_events', self.scale_up_ct)
metrics_object.set('dispatcher_pool_active_task_count', sum(len(w.managed_tasks) for w in self.workers))
metrics_object.set('dispatcher_pool_max_worker_count', self.worker_count_max)
self.worker_count_max = len(self.workers)
@property
def should_grow(self):
if len(self.workers) < self.min_workers:
# If we don't have at least min_workers, add more
return True
# If every worker is busy doing something, add more
return all([w.busy for w in self.workers])
@property
def full(self):
return len(self.workers) == self.max_workers
@property
def debug_meta(self):
return 'min={} max={}'.format(self.min_workers, self.max_workers)
@log_excess_runtime(logger, debug_cutoff=0.05, cutoff=0.2)
def cleanup(self):
"""
Perform some internal account and cleanup. This is run on
every cluster node heartbeat:
1. Discover worker processes that exited, and recover messages they
were handling.
2. Clean up unnecessary, idle workers.
IMPORTANT: this function is one of the few places in the dispatcher
(aside from setting lookups) where we talk to the database. As such,
if there's an outage, this method _can_ throw various
django.db.utils.Error exceptions. Act accordingly.
"""
orphaned = []
for w in self.workers[::]:
is_retirement_age = self.max_worker_lifetime_seconds is not None and w.age > self.max_worker_lifetime_seconds
if not w.alive:
# the worker process has exited
# 1. take the task it was running and enqueue the error
# callbacks
# 2. take any pending tasks delivered to its queue and
# send them to another worker
logger.error('worker pid:{} is gone (exit={})'.format(w.pid, w.exitcode))
if w.current_task:
if w.current_task == {'task': RETIRED_SENTINEL_TASK}:
logger.debug('scaling down worker pid:{} due to worker age: {}'.format(w.pid, w.age))
self.workers.remove(w)
continue
if w.current_task != 'QUIT':
try:
for j in UnifiedJob.objects.filter(celery_task_id=w.current_task['uuid']):
reaper.reap_job(j, 'failed')
except Exception:
logger.exception('failed to reap job UUID {}'.format(w.current_task['uuid']))
else:
logger.warning(f'Worker was told to quit but has not, pid={w.pid}')
orphaned.extend(w.orphaned_tasks)
self.workers.remove(w)
elif w.idle and len(self.workers) > self.min_workers:
# the process has an empty queue (it's idle) and we have
# more processes in the pool than we need (> min)
# send this process a message so it will exit gracefully
# at the next opportunity
logger.debug('scaling down worker pid:{}'.format(w.pid))
w.quit()
self.workers.remove(w)
elif w.idle and is_retirement_age:
logger.debug('scaling down worker pid:{} due to worker age: {}'.format(w.pid, w.age))
w.quit()
self.workers.remove(w)
elif is_retirement_age and not w.retiring and not w.idle:
logger.info(
f"Worker pid:{w.pid} (age: {w.age:.0f}s) exceeded max lifetime ({self.max_worker_lifetime_seconds:.0f}s). "
"Signaling for graceful retirement."
)
# Send QUIT signal; worker will finish current task then exit.
w.quit()
# mark as retiring to reject any future tasks that might be assigned in meantime
w.retiring = True
if w.alive:
# if we discover a task manager invocation that's been running
# too long, reap it (because otherwise it'll just hold the postgres
# advisory lock forever); the goal of this code is to discover
# deadlocks or other serious issues in the task manager that cause
# the task manager to never do more work
current_task = w.current_task
if current_task and isinstance(current_task, dict):
endings = ('tasks.task_manager', 'tasks.dependency_manager', 'tasks.workflow_manager')
current_task_name = current_task.get('task', '')
if current_task_name.endswith(endings):
if 'started' not in current_task:
w.managed_tasks[current_task['uuid']]['started'] = time.time()
age = time.time() - current_task['started']
w.managed_tasks[current_task['uuid']]['age'] = age
if age > self.task_manager_timeout:
logger.error(f'{current_task_name} has held the advisory lock for {age}, sending SIGUSR1 to {w.pid}')
os.kill(w.pid, signal.SIGUSR1)
for m in orphaned:
# if all the workers are dead, spawn at least one
if not len(self.workers):
self.up()
idx = random.choice(range(len(self.workers)))
self.write(idx, m)
def add_bind_kwargs(self, body):
bind_kwargs = body.pop('bind_kwargs', [])
body.setdefault('kwargs', {})
if 'dispatch_time' in bind_kwargs:
body['kwargs']['dispatch_time'] = tz_now().isoformat()
if 'worker_tasks' in bind_kwargs:
worker_tasks = {}
for worker in self.workers:
worker.calculate_managed_tasks()
worker_tasks[worker.pid] = list(worker.managed_tasks.keys())
body['kwargs']['worker_tasks'] = worker_tasks
def up(self):
if self.full:
# if we can't spawn more workers, just toss this message into a
# random worker's backlog
idx = random.choice(range(len(self.workers)))
return idx, self.workers[idx]
else:
self.scale_up_ct += 1
ret = super(AutoscalePool, self).up()
new_worker_ct = len(self.workers)
if new_worker_ct > self.worker_count_max:
self.worker_count_max = new_worker_ct
return ret
@staticmethod
def fast_task_serialization(current_task):
try:
return str(current_task.get('task')) + ' - ' + str(sorted(current_task.get('args', []))) + ' - ' + str(sorted(current_task.get('kwargs', {})))
except Exception:
# just make sure this does not make things worse
return str(current_task)
def write(self, preferred_queue, body):
if 'guid' in body:
set_guid(body['guid'])
try:
if isinstance(body, dict) and body.get('bind_kwargs'):
self.add_bind_kwargs(body)
if self.should_grow:
self.up()
# we don't care about "preferred queue" round robin distribution, just
# find the first non-busy worker and claim it
workers = self.workers[:]
random.shuffle(workers)
for w in workers:
if not w.busy:
w.put(body)
break
else: else:
task_name = 'unknown' logger.debug('scaling up worker pid:{}'.format(worker.process.pid))
if isinstance(body, dict):
task_name = body.get('task')
logger.warning(f'Workers maxed, queuing {task_name}, load: {sum(len(w.managed_tasks) for w in self.workers)} / {len(self.workers)}')
# Once every 10 seconds write out task list for debugging
if time.monotonic() - self.last_task_list_log >= 10.0:
task_counts = {}
for worker in self.workers:
task_slug = self.fast_task_serialization(worker.current_task)
task_counts.setdefault(task_slug, 0)
task_counts[task_slug] += 1
logger.info(f'Running tasks by count:\n{json.dumps(task_counts, indent=2)}')
self.last_task_list_log = time.monotonic()
return super(AutoscalePool, self).write(preferred_queue, body)
except Exception:
for conn in connections.all():
# If the database connection has a hiccup, re-establish a new
# connection
conn.close_if_unusable_or_obsolete()
logger.exception('failed to write inbound message')

View File

@@ -18,7 +18,7 @@ django.setup() # noqa
from django.conf import settings from django.conf import settings
# Preload all periodic tasks so their imports will be in shared memory # Preload all periodic tasks so their imports will be in shared memory
for name, options in settings.CELERYBEAT_SCHEDULE.items(): for name, options in settings.DISPATCHER_SCHEDULE.items():
resolve_callable(options['task']) resolve_callable(options['task'])

View File

@@ -1,163 +0,0 @@
import inspect
import logging
import json
import time
from uuid import uuid4
from dispatcherd.publish import submit_task
from dispatcherd.processors.blocker import Blocker
from dispatcherd.utils import resolve_callable
from django_guid import get_guid
from django.conf import settings
from . import pg_bus_conn
logger = logging.getLogger('awx.main.dispatch')
def serialize_task(f):
return '.'.join([f.__module__, f.__name__])
class task:
"""
Used to decorate a function or class so that it can be run asynchronously
via the task dispatcher. Tasks can be simple functions:
@task()
def add(a, b):
return a + b
...or classes that define a `run` method:
@task()
class Adder:
def run(self, a, b):
return a + b
# Tasks can be run synchronously...
assert add(1, 1) == 2
assert Adder().run(1, 1) == 2
# ...or published to a queue:
add.apply_async([1, 1])
Adder.apply_async([1, 1])
# Tasks can also define a specific target queue or use the special fan-out queue tower_broadcast:
@task(queue='slow-tasks')
def snooze():
time.sleep(10)
@task(queue='tower_broadcast')
def announce():
print("Run this everywhere!")
# The special parameter bind_kwargs tells the main dispatcher process to add certain kwargs
@task(bind_kwargs=['dispatch_time'])
def print_time(dispatch_time=None):
print(f"Time I was dispatched: {dispatch_time}")
"""
def __init__(self, queue=None, bind_kwargs=None, timeout=None, on_duplicate=None):
self.queue = queue
self.bind_kwargs = bind_kwargs
self.timeout = timeout
self.on_duplicate = on_duplicate
def __call__(self, fn=None):
queue = self.queue
bind_kwargs = self.bind_kwargs
timeout = self.timeout
on_duplicate = self.on_duplicate
class PublisherMixin(object):
queue = None
@classmethod
def delay(cls, *args, **kwargs):
return cls.apply_async(args, kwargs)
@classmethod
def get_async_body(cls, args=None, kwargs=None, uuid=None, **kw):
"""
Get the python dict to become JSON data in the pg_notify message
This same message gets passed over the dispatcher IPC queue to workers
If a task is submitted to a multiprocessing pool, skipping pg_notify, this might be used directly
"""
task_id = uuid or str(uuid4())
args = args or []
kwargs = kwargs or {}
obj = {'uuid': task_id, 'args': args, 'kwargs': kwargs, 'task': cls.name, 'time_pub': time.time()}
guid = get_guid()
if guid:
obj['guid'] = guid
if bind_kwargs:
obj['bind_kwargs'] = bind_kwargs
obj.update(**kw)
return obj
@classmethod
def apply_async(cls, args=None, kwargs=None, queue=None, uuid=None, **kw):
try:
from flags.state import flag_enabled
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
# At this point we have the import string, and submit_task wants the method, so back to that
actual_task = resolve_callable(cls.name)
processor_options = ()
if on_duplicate is not None:
processor_options = (Blocker.Params(on_duplicate=on_duplicate),)
return submit_task(
actual_task,
args=args,
kwargs=kwargs,
queue=queue,
uuid=uuid,
timeout=timeout,
processor_options=processor_options,
**kw,
)
except Exception:
logger.exception(f"[DISPATCHER] Failed to check for alternative dispatcherd implementation for {cls.name}")
# Continue with original implementation if anything fails
pass
# Original implementation follows
queue = queue or getattr(cls.queue, 'im_func', cls.queue)
if not queue:
msg = f'{cls.name}: Queue value required and may not be None'
logger.error(msg)
raise ValueError(msg)
obj = cls.get_async_body(args=args, kwargs=kwargs, uuid=uuid, **kw)
if callable(queue):
queue = queue()
if not settings.DISPATCHER_MOCK_PUBLISH:
with pg_bus_conn() as conn:
conn.notify(queue, json.dumps(obj))
return (obj, queue)
# If the object we're wrapping *is* a class (e.g., RunJob), return
# a *new* class that inherits from the wrapped class *and* BaseTask
# In this way, the new class returned by our decorator is the class
# being decorated *plus* PublisherMixin so cls.apply_async() and
# cls.delay() work
bases = []
ns = {'name': serialize_task(fn), 'queue': queue}
if inspect.isclass(fn):
bases = list(fn.__bases__)
ns.update(fn.__dict__)
cls = type(fn.__name__, tuple(bases + [PublisherMixin]), ns)
if inspect.isclass(fn):
return cls
# if the object being decorated is *not* a class (it's a Python
# function), make fn.apply_async and fn.delay proxy through to the
# PublisherMixin we dynamically created above
setattr(fn, 'name', cls.name)
setattr(fn, 'apply_async', cls.apply_async)
setattr(fn, 'delay', cls.delay)
setattr(fn, 'get_async_body', cls.get_async_body)
return fn

View File

@@ -1,9 +1,6 @@
from datetime import timedelta
import logging import logging
from django.db.models import Q from django.db.models import Q
from django.conf import settings
from django.utils.timezone import now as tz_now
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from awx.main.models import Instance, UnifiedJob, WorkflowJob from awx.main.models import Instance, UnifiedJob, WorkflowJob
@@ -50,26 +47,6 @@ def reap_job(j, status, job_explanation=None):
logger.error(f'{j.log_format} is no longer {status_before}; reaping') logger.error(f'{j.log_format} is no longer {status_before}; reaping')
def reap_waiting(instance=None, status='failed', job_explanation=None, grace_period=None, excluded_uuids=None, ref_time=None):
"""
Reap all jobs in waiting for this instance.
"""
if grace_period is None:
grace_period = settings.JOB_WAITING_GRACE_PERIOD + settings.TASK_MANAGER_TIMEOUT
if instance is None:
hostname = Instance.objects.my_hostname()
else:
hostname = instance.hostname
if ref_time is None:
ref_time = tz_now()
jobs = UnifiedJob.objects.filter(status='waiting', modified__lte=ref_time - timedelta(seconds=grace_period), controller_node=hostname)
if excluded_uuids:
jobs = jobs.exclude(celery_task_id__in=excluded_uuids)
for j in jobs:
reap_job(j, status, job_explanation=job_explanation)
def reap(instance=None, status='failed', job_explanation=None, excluded_uuids=None, ref_time=None): def reap(instance=None, status='failed', job_explanation=None, excluded_uuids=None, ref_time=None):
""" """
Reap all jobs in running for this instance. Reap all jobs in running for this instance.

View File

@@ -1,3 +1,2 @@
from .base import AWXConsumerRedis, AWXConsumerPG, BaseWorker # noqa from .base import AWXConsumerRedis # noqa
from .callback import CallbackBrokerWorker # noqa from .callback import CallbackBrokerWorker # noqa
from .task import TaskWorker # noqa

View File

@@ -4,342 +4,39 @@
import os import os
import logging import logging
import signal import signal
import sys
import redis
import json
import psycopg
import time import time
from uuid import UUID
from queue import Empty as QueueEmpty
from datetime import timedelta
from django import db from django import db
from django.conf import settings
import redis.exceptions
from ansible_base.lib.logging.runtime import log_excess_runtime
from awx.main.utils.redis import get_redis_client from awx.main.utils.redis import get_redis_client
from awx.main.dispatch.pool import WorkerPool from awx.main.dispatch.pool import WorkerPool
from awx.main.dispatch.periodic import Scheduler
from awx.main.dispatch import pg_bus_conn
from awx.main.utils.db import set_connection_name
import awx.main.analytics.subsystem_metrics as s_metrics
if 'run_callback_receiver' in sys.argv: logger = logging.getLogger('awx.main.commands.run_callback_receiver')
logger = logging.getLogger('awx.main.commands.run_callback_receiver')
else:
logger = logging.getLogger('awx.main.dispatch')
def signame(sig): def signame(sig):
return dict((k, v) for v, k in signal.__dict__.items() if v.startswith('SIG') and not v.startswith('SIG_'))[sig] return dict((k, v) for v, k in signal.__dict__.items() if v.startswith('SIG') and not v.startswith('SIG_'))[sig]
class WorkerSignalHandler: class AWXConsumerRedis(object):
def __init__(self):
self.kill_now = False
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGINT, self.exit_gracefully)
def exit_gracefully(self, *args, **kwargs):
self.kill_now = True
class AWXConsumerBase(object):
last_stats = time.time()
def __init__(self, name, worker, queues=[], pool=None):
self.should_stop = False
def __init__(self, name, worker):
self.name = name self.name = name
self.total_messages = 0 self.pool = WorkerPool()
self.queues = queues self.pool.init_workers(worker.work_loop)
self.worker = worker
self.pool = pool
if pool is None:
self.pool = WorkerPool()
self.pool.init_workers(self.worker.work_loop)
self.redis = get_redis_client() self.redis = get_redis_client()
@property def run(self):
def listening_on(self):
return f'listening on {self.queues}'
def control(self, body):
logger.warning(f'Received control signal:\n{body}')
control = body.get('control')
if control in ('status', 'schedule', 'running', 'cancel'):
reply_queue = body['reply_to']
if control == 'status':
msg = '\n'.join([self.listening_on, self.pool.debug()])
if control == 'schedule':
msg = self.scheduler.debug()
elif control == 'running':
msg = []
for worker in self.pool.workers:
worker.calculate_managed_tasks()
msg.extend(worker.managed_tasks.keys())
elif control == 'cancel':
msg = []
task_ids = set(body['task_ids'])
for worker in self.pool.workers:
task = worker.current_task
if task and task['uuid'] in task_ids:
logger.warn(f'Sending SIGTERM to task id={task["uuid"]}, task={task.get("task")}, args={task.get("args")}')
os.kill(worker.pid, signal.SIGTERM)
msg.append(task['uuid'])
if task_ids and not msg:
logger.info(f'Could not locate running tasks to cancel with ids={task_ids}')
if reply_queue is not None:
with pg_bus_conn() as conn:
conn.notify(reply_queue, json.dumps(msg))
elif control == 'reload':
for worker in self.pool.workers:
worker.quit()
else:
logger.error('unrecognized control message: {}'.format(control))
def dispatch_task(self, body):
"""This will place the given body into a worker queue to run method decorated as a task"""
if isinstance(body, dict):
body['time_ack'] = time.time()
if len(self.pool):
if "uuid" in body and body['uuid']:
try:
queue = UUID(body['uuid']).int % len(self.pool)
except Exception:
queue = self.total_messages % len(self.pool)
else:
queue = self.total_messages % len(self.pool)
else:
queue = 0
self.pool.write(queue, body)
self.total_messages += 1
def process_task(self, body):
"""Routes the task details in body as either a control task or a task-task"""
if 'control' in body:
try:
return self.control(body)
except Exception:
logger.exception(f"Exception handling control message: {body}")
return
self.dispatch_task(body)
@log_excess_runtime(logger, debug_cutoff=0.05, cutoff=0.2)
def record_statistics(self):
if time.time() - self.last_stats > 1: # buffer stat recording to once per second
save_data = self.pool.debug()
try:
self.redis.set(f'awx_{self.name}_statistics', save_data)
except redis.exceptions.ConnectionError as exc:
logger.warning(f'Redis connection error saving {self.name} status data:\n{exc}\nmissed data:\n{save_data}')
except Exception:
logger.exception(f"Unknown redis error saving {self.name} status data:\nmissed data:\n{save_data}")
self.last_stats = time.time()
def run(self, *args, **kwargs):
signal.signal(signal.SIGINT, self.stop) signal.signal(signal.SIGINT, self.stop)
signal.signal(signal.SIGTERM, self.stop) signal.signal(signal.SIGTERM, self.stop)
# Child should implement other things here
def stop(self, signum, frame):
self.should_stop = True
logger.warning('received {}, stopping'.format(signame(signum)))
self.worker.on_stop()
raise SystemExit()
class AWXConsumerRedis(AWXConsumerBase):
def run(self, *args, **kwargs):
super(AWXConsumerRedis, self).run(*args, **kwargs)
self.worker.on_start()
logger.info(f'Callback receiver started with pid={os.getpid()}') logger.info(f'Callback receiver started with pid={os.getpid()}')
db.connection.close() # logs use database, so close connection db.connection.close() # logs use database, so close connection
while True: while True:
time.sleep(60) time.sleep(60)
def stop(self, signum, frame):
class AWXConsumerPG(AWXConsumerBase): logger.warning('received {}, stopping'.format(signame(signum)))
def __init__(self, *args, schedule=None, **kwargs): raise SystemExit()
super().__init__(*args, **kwargs)
self.pg_max_wait = getattr(settings, 'DISPATCHER_DB_DOWNTOWN_TOLLERANCE', settings.DISPATCHER_DB_DOWNTIME_TOLERANCE)
# if no successful loops have ran since startup, then we should fail right away
self.pg_is_down = True # set so that we fail if we get database errors on startup
init_time = time.time()
self.pg_down_time = init_time - self.pg_max_wait # allow no grace period
self.last_cleanup = init_time
self.subsystem_metrics = s_metrics.DispatcherMetrics(auto_pipe_execute=False)
self.last_metrics_gather = init_time
self.listen_cumulative_time = 0.0
if schedule:
schedule = schedule.copy()
else:
schedule = {}
# add control tasks to be ran at regular schedules
# NOTE: if we run out of database connections, it is important to still run cleanup
# so that we scale down workers and free up connections
schedule['pool_cleanup'] = {'control': self.pool.cleanup, 'schedule': timedelta(seconds=60)}
# record subsystem metrics for the dispatcher
schedule['metrics_gather'] = {'control': self.record_metrics, 'schedule': timedelta(seconds=20)}
self.scheduler = Scheduler(schedule)
@log_excess_runtime(logger, debug_cutoff=0.05, cutoff=0.2)
def record_metrics(self):
current_time = time.time()
self.pool.produce_subsystem_metrics(self.subsystem_metrics)
self.subsystem_metrics.set('dispatcher_availability', self.listen_cumulative_time / (current_time - self.last_metrics_gather))
try:
self.subsystem_metrics.pipe_execute()
except redis.exceptions.ConnectionError as exc:
logger.warning(f'Redis connection error saving dispatcher metrics, error:\n{exc}')
self.listen_cumulative_time = 0.0
self.last_metrics_gather = current_time
def run_periodic_tasks(self):
"""
Run general periodic logic, and return maximum time in seconds before
the next requested run
This may be called more often than that when events are consumed
so this should be very efficient in that
"""
try:
self.record_statistics() # maintains time buffer in method
except Exception as exc:
logger.warning(f'Failed to save dispatcher statistics {exc}')
# Everything benchmarks to the same original time, so that skews due to
# runtime of the actions, themselves, do not mess up scheduling expectations
reftime = time.time()
for job in self.scheduler.get_and_mark_pending(reftime=reftime):
if 'control' in job.data:
try:
job.data['control']()
except Exception:
logger.exception(f'Error running control task {job.data}')
elif 'task' in job.data:
body = self.worker.resolve_callable(job.data['task']).get_async_body()
# bypasses pg_notify for scheduled tasks
self.dispatch_task(body)
if self.pg_is_down:
logger.info('Dispatcher listener connection established')
self.pg_is_down = False
self.listen_start = time.time()
return self.scheduler.time_until_next_run(reftime=reftime)
def run(self, *args, **kwargs):
super(AWXConsumerPG, self).run(*args, **kwargs)
logger.info(f"Running {self.name}, workers min={self.pool.min_workers} max={self.pool.max_workers}, listening to queues {self.queues}")
init = False
while True:
try:
with pg_bus_conn(new_connection=True) as conn:
for queue in self.queues:
conn.listen(queue)
if init is False:
self.worker.on_start()
init = True
# run_periodic_tasks run scheduled actions and gives time until next scheduled action
# this is saved to the conn (PubSub) object in order to modify read timeout in-loop
conn.select_timeout = self.run_periodic_tasks()
# this is the main operational loop for awx-manage run_dispatcher
for e in conn.events(yield_timeouts=True):
self.listen_cumulative_time += time.time() - self.listen_start # for metrics
if e is not None:
self.process_task(json.loads(e.payload))
conn.select_timeout = self.run_periodic_tasks()
if self.should_stop:
return
except psycopg.InterfaceError:
logger.warning("Stale Postgres message bus connection, reconnecting")
continue
except (db.DatabaseError, psycopg.OperationalError):
# If we have attained stady state operation, tolerate short-term database hickups
if not self.pg_is_down:
logger.exception(f"Error consuming new events from postgres, will retry for {self.pg_max_wait} s")
self.pg_down_time = time.time()
self.pg_is_down = True
current_downtime = time.time() - self.pg_down_time
if current_downtime > self.pg_max_wait:
logger.exception(f"Postgres event consumer has not recovered in {current_downtime} s, exiting")
# Sending QUIT to multiprocess queue to signal workers to exit
for worker in self.pool.workers:
try:
worker.quit()
except Exception:
logger.exception(f"Error sending QUIT to worker {worker}")
raise
# Wait for a second before next attempt, but still listen for any shutdown signals
for i in range(10):
if self.should_stop:
return
time.sleep(0.1)
for conn in db.connections.all():
conn.close_if_unusable_or_obsolete()
except Exception:
# Log unanticipated exception in addition to writing to stderr to get timestamps and other metadata
logger.exception('Encountered unhandled error in dispatcher main loop')
# Sending QUIT to multiprocess queue to signal workers to exit
for worker in self.pool.workers:
try:
worker.quit()
except Exception:
logger.exception(f"Error sending QUIT to worker {worker}")
raise
class BaseWorker(object):
def read(self, queue):
return queue.get(block=True, timeout=1)
def work_loop(self, queue, finished, idx, *args):
ppid = os.getppid()
signal_handler = WorkerSignalHandler()
set_connection_name('worker') # set application_name to distinguish from other dispatcher processes
while not signal_handler.kill_now:
# if the parent PID changes, this process has been orphaned
# via e.g., segfault or sigkill, we should exit too
if os.getppid() != ppid:
break
try:
body = self.read(queue)
if body == 'QUIT':
break
except QueueEmpty:
continue
except Exception:
logger.exception("Exception on worker {}, reconnecting: ".format(idx))
continue
try:
for conn in db.connections.all():
# If the database connection has a hiccup during the prior message, close it
# so we can establish a new connection
conn.close_if_unusable_or_obsolete()
self.perform_work(body, *args)
except Exception:
logger.exception(f'Unhandled exception in perform_work in worker pid={os.getpid()}')
finally:
if 'uuid' in body:
uuid = body['uuid']
finished.put(uuid)
logger.debug('worker exiting gracefully pid:{}'.format(os.getpid()))
def perform_work(self, body):
raise NotImplementedError()
def on_start(self):
pass
def on_stop(self):
pass

View File

@@ -4,10 +4,12 @@ import os
import signal import signal
import time import time
import datetime import datetime
from queue import Empty as QueueEmpty
from django.conf import settings from django.conf import settings
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.timezone import now as tz_now from django.utils.timezone import now as tz_now
from django import db
from django.db import transaction, connection as django_connection from django.db import transaction, connection as django_connection
from django_guid import set_guid from django_guid import set_guid
@@ -16,6 +18,7 @@ import psutil
import redis import redis
from awx.main.utils.redis import get_redis_client from awx.main.utils.redis import get_redis_client
from awx.main.utils.db import set_connection_name
from awx.main.consumers import emit_channel_notification from awx.main.consumers import emit_channel_notification
from awx.main.models import JobEvent, AdHocCommandEvent, ProjectUpdateEvent, InventoryUpdateEvent, SystemJobEvent, UnifiedJob from awx.main.models import JobEvent, AdHocCommandEvent, ProjectUpdateEvent, InventoryUpdateEvent, SystemJobEvent, UnifiedJob
from awx.main.constants import ACTIVE_STATES from awx.main.constants import ACTIVE_STATES
@@ -23,7 +26,6 @@ from awx.main.models.events import emit_event_detail
from awx.main.utils.profiling import AWXProfiler from awx.main.utils.profiling import AWXProfiler
from awx.main.tasks.system import events_processed_hook from awx.main.tasks.system import events_processed_hook
import awx.main.analytics.subsystem_metrics as s_metrics import awx.main.analytics.subsystem_metrics as s_metrics
from .base import BaseWorker
logger = logging.getLogger('awx.main.commands.run_callback_receiver') logger = logging.getLogger('awx.main.commands.run_callback_receiver')
@@ -54,7 +56,17 @@ def job_stats_wrapup(job_identifier, event=None):
logger.exception('Worker failed to save stats or emit notifications: Job {}'.format(job_identifier)) logger.exception('Worker failed to save stats or emit notifications: Job {}'.format(job_identifier))
class CallbackBrokerWorker(BaseWorker): class WorkerSignalHandler:
def __init__(self):
self.kill_now = False
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGINT, self.exit_gracefully)
def exit_gracefully(self, *args, **kwargs):
self.kill_now = True
class CallbackBrokerWorker:
""" """
A worker implementation that deserializes callback event data and persists A worker implementation that deserializes callback event data and persists
it into the database. it into the database.
@@ -86,7 +98,7 @@ class CallbackBrokerWorker(BaseWorker):
"""This needs to be obtained after forking, or else it will give the parent process""" """This needs to be obtained after forking, or else it will give the parent process"""
return os.getpid() return os.getpid()
def read(self, queue): def read(self):
has_redis_error = False has_redis_error = False
try: try:
res = self.redis.blpop(self.queue_name, timeout=1) res = self.redis.blpop(self.queue_name, timeout=1)
@@ -149,10 +161,37 @@ class CallbackBrokerWorker(BaseWorker):
filepath = self.prof.stop() filepath = self.prof.stop()
logger.error(f'profiling is disabled, wrote {filepath}') logger.error(f'profiling is disabled, wrote {filepath}')
def work_loop(self, *args, **kw): def work_loop(self, idx, *args):
if settings.AWX_CALLBACK_PROFILE: if settings.AWX_CALLBACK_PROFILE:
signal.signal(signal.SIGUSR1, self.toggle_profiling) signal.signal(signal.SIGUSR1, self.toggle_profiling)
return super(CallbackBrokerWorker, self).work_loop(*args, **kw)
ppid = os.getppid()
signal_handler = WorkerSignalHandler()
set_connection_name('worker') # set application_name to distinguish from other dispatcher processes
while not signal_handler.kill_now:
# if the parent PID changes, this process has been orphaned
# via e.g., segfault or sigkill, we should exit too
if os.getppid() != ppid:
break
try:
body = self.read() # this is only for the callback, only reading from redis.
if body == 'QUIT':
break
except QueueEmpty:
continue
except Exception:
logger.exception("Exception on worker {}, reconnecting: ".format(idx))
continue
try:
for conn in db.connections.all():
# If the database connection has a hiccup during the prior message, close it
# so we can establish a new connection
conn.close_if_unusable_or_obsolete()
self.perform_work(body, *args)
except Exception:
logger.exception(f'Unhandled exception in perform_work in worker pid={os.getpid()}')
logger.debug('worker exiting gracefully pid:{}'.format(os.getpid()))
def flush(self, force=False): def flush(self, force=False):
now = tz_now() now = tz_now()

View File

@@ -1,144 +1,49 @@
import inspect
import logging import logging
import importlib import importlib
import sys
import traceback
import time import time
from kubernetes.config import kube_config
from django.conf import settings
from django_guid import set_guid from django_guid import set_guid
from awx.main.tasks.system import dispatch_startup, inform_cluster_of_shutdown
from .base import BaseWorker
logger = logging.getLogger('awx.main.dispatch') logger = logging.getLogger('awx.main.dispatch')
class TaskWorker(BaseWorker): def resolve_callable(task):
""" """
A worker implementation that deserializes task messages and runs native Transform a dotted notation task into an imported, callable function, e.g.,
Python code. awx.main.tasks.system.delete_inventory
awx.main.tasks.jobs.RunProjectUpdate
The code that *builds* these types of messages is found in
`awx.main.dispatch.publish`.
""" """
if not task.startswith('awx.'):
raise ValueError('{} is not a valid awx task'.format(task))
module, target = task.rsplit('.', 1)
module = importlib.import_module(module)
_call = None
if hasattr(module, target):
_call = getattr(module, target, None)
if not (hasattr(_call, 'apply_async') and hasattr(_call, 'delay')):
raise ValueError('{} is not decorated with @task()'.format(task))
return _call
@staticmethod
def resolve_callable(task):
"""
Transform a dotted notation task into an imported, callable function, e.g.,
awx.main.tasks.system.delete_inventory def run_callable(body):
awx.main.tasks.jobs.RunProjectUpdate """
""" Given some AMQP message, import the correct Python code and run it.
if not task.startswith('awx.'): """
raise ValueError('{} is not a valid awx task'.format(task)) task = body['task']
module, target = task.rsplit('.', 1) uuid = body.get('uuid', '<unknown>')
module = importlib.import_module(module) args = body.get('args', [])
_call = None kwargs = body.get('kwargs', {})
if hasattr(module, target): if 'guid' in body:
_call = getattr(module, target, None) set_guid(body.pop('guid'))
if not (hasattr(_call, 'apply_async') and hasattr(_call, 'delay')): _call = resolve_callable(task)
raise ValueError('{} is not decorated with @task()'.format(task)) log_extra = ''
logger_method = logger.debug
return _call if 'time_pub' in body:
time_publish = time.time() - body['time_pub']
@staticmethod if time_publish > 5.0:
def run_callable(body): # If task too a very long time to process, add this information to the log
""" log_extra = f' took {time_publish:.4f} to send message'
Given some AMQP message, import the correct Python code and run it. logger_method = logger.info
""" # don't print kwargs, they often contain launch-time secrets
task = body['task'] logger_method(f'task {uuid} starting {task}(*{args}){log_extra}')
uuid = body.get('uuid', '<unknown>') return _call(*args, **kwargs)
args = body.get('args', [])
kwargs = body.get('kwargs', {})
if 'guid' in body:
set_guid(body.pop('guid'))
_call = TaskWorker.resolve_callable(task)
if inspect.isclass(_call):
# the callable is a class, e.g., RunJob; instantiate and
# return its `run()` method
_call = _call().run
log_extra = ''
logger_method = logger.debug
if ('time_ack' in body) and ('time_pub' in body):
time_publish = body['time_ack'] - body['time_pub']
time_waiting = time.time() - body['time_ack']
if time_waiting > 5.0 or time_publish > 5.0:
# If task too a very long time to process, add this information to the log
log_extra = f' took {time_publish:.4f} to ack, {time_waiting:.4f} in local dispatcher'
logger_method = logger.info
# don't print kwargs, they often contain launch-time secrets
logger_method(f'task {uuid} starting {task}(*{args}){log_extra}')
return _call(*args, **kwargs)
def perform_work(self, body):
"""
Import and run code for a task e.g.,
body = {
'args': [8],
'callbacks': [{
'args': [],
'kwargs': {}
'task': u'awx.main.tasks.system.handle_work_success'
}],
'errbacks': [{
'args': [],
'kwargs': {},
'task': 'awx.main.tasks.system.handle_work_error'
}],
'kwargs': {},
'task': u'awx.main.tasks.jobs.RunProjectUpdate'
}
"""
settings.__clean_on_fork__()
result = None
try:
result = self.run_callable(body)
except Exception as exc:
result = exc
try:
if getattr(exc, 'is_awx_task_error', False):
# Error caused by user / tracked in job output
logger.warning("{}".format(exc))
else:
task = body['task']
args = body.get('args', [])
kwargs = body.get('kwargs', {})
logger.exception('Worker failed to run task {}(*{}, **{}'.format(task, args, kwargs))
except Exception:
# It's fairly critical that this code _not_ raise exceptions on logging
# If you configure external logging in a way that _it_ fails, there's
# not a lot we can do here; sys.stderr.write is a final hail mary
_, _, tb = sys.exc_info()
traceback.print_tb(tb)
for callback in body.get('errbacks', []) or []:
callback['uuid'] = body['uuid']
self.perform_work(callback)
finally:
# It's frustrating that we have to do this, but the python k8s
# client leaves behind cacert files in /tmp, so we must clean up
# the tmpdir per-dispatcher process every time a new task comes in
try:
kube_config._cleanup_temp_files()
except Exception:
logger.exception('failed to cleanup k8s client tmp files')
for callback in body.get('callbacks', []) or []:
callback['uuid'] = body['uuid']
self.perform_work(callback)
return result
def on_start(self):
dispatch_startup()
def on_stop(self):
inform_cluster_of_shutdown()

View File

@@ -0,0 +1,88 @@
import argparse
import inspect
import logging
import os
import sys
import yaml
from django.core.management.base import BaseCommand, CommandError
from django.db import connection
from dispatcherd.cli import (
CONTROL_ARG_SCHEMAS,
DEFAULT_CONFIG_FILE,
_base_cli_parent,
_control_common_parent,
_register_control_arguments,
_build_command_data_from_args,
)
from dispatcherd.config import setup as dispatcher_setup
from dispatcherd.factories import get_control_from_settings
from dispatcherd.service import control_tasks
from awx.main.dispatch.config import get_dispatcherd_config
from awx.main.management.commands.dispatcherd import ensure_no_dispatcherd_env_config
logger = logging.getLogger(__name__)
class Command(BaseCommand):
help = 'Dispatcher control operations'
def add_arguments(self, parser):
parser.description = 'Run dispatcherd control commands using awx-manage.'
base_parent = _base_cli_parent()
control_parent = _control_common_parent()
parser._add_container_actions(base_parent)
parser._add_container_actions(control_parent)
subparsers = parser.add_subparsers(dest='command', metavar='command')
subparsers.required = True
shared_parents = [base_parent, control_parent]
for command in control_tasks.__all__:
func = getattr(control_tasks, command, None)
doc = inspect.getdoc(func) or ''
summary = doc.splitlines()[0] if doc else None
command_parser = subparsers.add_parser(
command,
help=summary,
description=doc,
parents=shared_parents,
)
_register_control_arguments(command_parser, CONTROL_ARG_SCHEMAS.get(command))
def handle(self, *args, **options):
command = options.pop('command', None)
if not command:
raise CommandError('No dispatcher control command specified')
for django_opt in ('verbosity', 'traceback', 'no_color', 'force_color', 'skip_checks'):
options.pop(django_opt, None)
log_level = options.pop('log_level', 'DEBUG')
config_path = os.path.abspath(options.pop('config', DEFAULT_CONFIG_FILE))
expected_replies = options.pop('expected_replies', 1)
logging.basicConfig(level=getattr(logging, log_level), stream=sys.stdout)
logger.debug(f"Configured standard out logging at {log_level} level")
default_config = os.path.abspath(DEFAULT_CONFIG_FILE)
ensure_no_dispatcherd_env_config()
if config_path != default_config:
raise CommandError('The config path CLI option is not allowed for the awx-manage command')
if connection.vendor == 'sqlite':
raise CommandError('dispatcherctl is not supported with sqlite3; use a PostgreSQL database')
else:
logger.info('Using config generated from awx.main.dispatch.config.get_dispatcherd_config')
dispatcher_setup(get_dispatcherd_config())
schema_namespace = argparse.Namespace(**options)
data = _build_command_data_from_args(schema_namespace, command)
ctl = get_control_from_settings()
returned = ctl.control_with_reply(command, data=data, expected_replies=expected_replies)
self.stdout.write(yaml.dump(returned, default_flow_style=False))
if len(returned) < expected_replies:
logger.error(f'Obtained only {len(returned)} of {expected_replies}, exiting with non-zero code')
raise CommandError('dispatcherctl returned fewer replies than expected')

View File

@@ -0,0 +1,85 @@
# Copyright (c) 2015 Ansible, Inc.
# All Rights Reserved
import copy
import hashlib
import json
import logging
import logging.config
import os
from django.conf import settings
from django.core.cache import cache as django_cache
from django.core.management.base import BaseCommand, CommandError
from django.db import connection
from dispatcherd.config import setup as dispatcher_setup
from awx.main.dispatch.config import get_dispatcherd_config
logger = logging.getLogger('awx.main.dispatch')
from dispatcherd import run_service
def _json_default(value):
if isinstance(value, set):
return sorted(value)
if isinstance(value, tuple):
return list(value)
return str(value)
def _hash_config(config):
serialized = json.dumps(config, sort_keys=True, separators=(',', ':'), default=_json_default)
return hashlib.sha256(serialized.encode('utf-8')).hexdigest()
def ensure_no_dispatcherd_env_config():
if os.getenv('DISPATCHERD_CONFIG_FILE'):
raise CommandError('DISPATCHERD_CONFIG_FILE is set but awx-manage dispatcherd uses dynamic config from code')
class Command(BaseCommand):
help = (
'Run the background task service, this is the supported entrypoint since the introduction of dispatcherd as a library. '
'This replaces the prior awx-manage run_dispatcher service, and control actions are at awx-manage dispatcherctl.'
)
def add_arguments(self, parser):
return
def handle(self, *arg, **options):
ensure_no_dispatcherd_env_config()
self.configure_dispatcher_logging()
config = get_dispatcherd_config(for_service=True)
config_hash = _hash_config(config)
logger.info(
'Using dispatcherd config generated from awx.main.dispatch.config.get_dispatcherd_config (sha256=%s)',
config_hash,
)
# Close the connection, because the pg_notify broker will create new async connection
connection.close()
django_cache.close()
dispatcher_setup(config)
run_service()
def configure_dispatcher_logging(self):
# Apply special log rule for the parent process
special_logging = copy.deepcopy(settings.LOGGING)
changed_handlers = []
for handler_name, handler_config in special_logging.get('handlers', {}).items():
filters = handler_config.get('filters', [])
if 'dynamic_level_filter' in filters:
handler_config['filters'] = [flt for flt in filters if flt != 'dynamic_level_filter']
changed_handlers.append(handler_name)
logger.info(f'Dispatcherd main process replaced log level filter for handlers: {changed_handlers}')
# Apply the custom logging level here, before the asyncio code starts
special_logging.setdefault('loggers', {}).setdefault('dispatcherd', {})
special_logging['loggers']['dispatcherd']['level'] = settings.LOG_AGGREGATOR_LEVEL
logging.config.dictConfig(special_logging)

View File

@@ -4,7 +4,7 @@ import json
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from awx.main.dispatch import pg_bus_conn from awx.main.dispatch import pg_bus_conn
from awx.main.dispatch.worker.task import TaskWorker from awx.main.dispatch.worker.task import run_callable
logger = logging.getLogger('awx.main.cache_clear') logger = logging.getLogger('awx.main.cache_clear')
@@ -21,11 +21,11 @@ class Command(BaseCommand):
try: try:
with pg_bus_conn() as conn: with pg_bus_conn() as conn:
conn.listen("tower_settings_change") conn.listen("tower_settings_change")
for e in conn.events(yield_timeouts=True): for e in conn.events():
if e is not None: if e is not None:
body = json.loads(e.payload) body = json.loads(e.payload)
logger.info(f"Cache clear request received. Clearing now, payload: {e.payload}") logger.info(f"Cache clear request received. Clearing now, payload: {e.payload}")
TaskWorker.run_callable(body) run_callable(body)
except Exception: except Exception:
# Log unanticipated exception in addition to writing to stderr to get timestamps and other metadata # Log unanticipated exception in addition to writing to stderr to get timestamps and other metadata

View File

@@ -3,13 +3,12 @@
import redis import redis
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError from django.core.management.base import BaseCommand, CommandError
import redis.exceptions import redis.exceptions
from awx.main.analytics.subsystem_metrics import CallbackReceiverMetricsServer from awx.main.analytics.subsystem_metrics import CallbackReceiverMetricsServer
from awx.main.dispatch.control import Control
from awx.main.dispatch.worker import AWXConsumerRedis, CallbackBrokerWorker from awx.main.dispatch.worker import AWXConsumerRedis, CallbackBrokerWorker
from awx.main.utils.redis import get_redis_client
class Command(BaseCommand): class Command(BaseCommand):
@@ -26,7 +25,7 @@ class Command(BaseCommand):
def handle(self, *arg, **options): def handle(self, *arg, **options):
if options.get('status'): if options.get('status'):
print(Control('callback_receiver').status()) print(self.status())
return return
consumer = None consumer = None
@@ -36,13 +35,16 @@ class Command(BaseCommand):
raise CommandError(f'Callback receiver could not connect to redis, error: {exc}') raise CommandError(f'Callback receiver could not connect to redis, error: {exc}')
try: try:
consumer = AWXConsumerRedis( consumer = AWXConsumerRedis('callback_receiver', CallbackBrokerWorker())
'callback_receiver',
CallbackBrokerWorker(),
queues=[getattr(settings, 'CALLBACK_QUEUE', '')],
)
consumer.run() consumer.run()
except KeyboardInterrupt: except KeyboardInterrupt:
print('Terminating Callback Receiver') print('Terminating Callback Receiver')
if consumer: if consumer:
consumer.stop() consumer.stop()
def status(self, *args, **kwargs):
r = get_redis_client()
workers = []
for key in r.keys('awx_callback_receiver_statistics_*'):
workers.append(r.get(key).decode('utf-8'))
return '\n'.join(workers)

View File

@@ -1,46 +1,24 @@
# Copyright (c) 2015 Ansible, Inc. # Copyright (c) 2015 Ansible, Inc.
# All Rights Reserved. # All Rights Reserved.
import logging import logging
import logging.config
import yaml import yaml
import copy
import redis from django.core.management.base import CommandError
from django.conf import settings
from django.db import connection
from django.core.management.base import BaseCommand, CommandError
from django.core.cache import cache as django_cache
from flags.state import flag_enabled
from dispatcherd.factories import get_control_from_settings from dispatcherd.factories import get_control_from_settings
from dispatcherd import run_service
from dispatcherd.config import setup as dispatcher_setup
from awx.main.dispatch import get_task_queuename from awx.main.management.commands.dispatcherd import Command as DispatcherdCommand
from awx.main.dispatch.config import get_dispatcherd_config
from awx.main.dispatch.control import Control
from awx.main.dispatch.pool import AutoscalePool
from awx.main.dispatch.worker import AWXConsumerPG, TaskWorker
from awx.main.analytics.subsystem_metrics import DispatcherMetricsServer
logger = logging.getLogger('awx.main.dispatch') logger = logging.getLogger('awx.main.dispatch')
class Command(BaseCommand): class Command(DispatcherdCommand):
help = 'Launch the task dispatcher' help = 'Launch the task dispatcher (deprecated; use awx-manage dispatcherd)'
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument('--status', dest='status', action='store_true', help='print the internal state of any running dispatchers') parser.add_argument('--status', dest='status', action='store_true', help='print the internal state of any running dispatchers')
parser.add_argument('--schedule', dest='schedule', action='store_true', help='print the current status of schedules being ran by dispatcher')
parser.add_argument('--running', dest='running', action='store_true', help='print the UUIDs of any tasked managed by this dispatcher') parser.add_argument('--running', dest='running', action='store_true', help='print the UUIDs of any tasked managed by this dispatcher')
parser.add_argument(
'--reload',
dest='reload',
action='store_true',
help=('cause the dispatcher to recycle all of its worker processes; running jobs will run to completion first'),
)
parser.add_argument( parser.add_argument(
'--cancel', '--cancel',
dest='cancel', dest='cancel',
@@ -50,41 +28,22 @@ class Command(BaseCommand):
'Only running tasks can be canceled, queued tasks must be started before they can be canceled.' 'Only running tasks can be canceled, queued tasks must be started before they can be canceled.'
), ),
) )
super().add_arguments(parser)
def handle(self, *arg, **options): def handle(self, *args, **options):
logger.warning('awx-manage run_dispatcher is deprecated; use awx-manage dispatcherd')
if options.get('status'): if options.get('status'):
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): ctl = get_control_from_settings()
ctl = get_control_from_settings() running_data = ctl.control_with_reply('status')
running_data = ctl.control_with_reply('status') if len(running_data) != 1:
if len(running_data) != 1: raise CommandError('Did not receive expected number of replies')
raise CommandError('Did not receive expected number of replies') print(yaml.dump(running_data[0], default_flow_style=False))
print(yaml.dump(running_data[0], default_flow_style=False))
return
else:
print(Control('dispatcher').status())
return
if options.get('schedule'):
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
print('NOT YET IMPLEMENTED')
return
else:
print(Control('dispatcher').schedule())
return return
if options.get('running'): if options.get('running'):
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): ctl = get_control_from_settings()
ctl = get_control_from_settings() running_data = ctl.control_with_reply('running')
running_data = ctl.control_with_reply('running') print(yaml.dump(running_data, default_flow_style=False))
print(yaml.dump(running_data, default_flow_style=False)) return
return
else:
print(Control('dispatcher').running())
return
if options.get('reload'):
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
print('NOT YET IMPLEMENTED')
return
else:
return Control('dispatcher').control({'control': 'reload'})
if options.get('cancel'): if options.get('cancel'):
cancel_str = options.get('cancel') cancel_str = options.get('cancel')
try: try:
@@ -94,56 +53,12 @@ class Command(BaseCommand):
if not isinstance(cancel_data, list): if not isinstance(cancel_data, list):
cancel_data = [cancel_str] cancel_data = [cancel_str]
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): ctl = get_control_from_settings()
ctl = get_control_from_settings() results = []
results = [] for task_id in cancel_data:
for task_id in cancel_data: # For each task UUID, send an individual cancel command
# For each task UUID, send an individual cancel command result = ctl.control_with_reply('cancel', data={'uuid': task_id})
result = ctl.control_with_reply('cancel', data={'uuid': task_id}) results.append(result)
results.append(result) print(yaml.dump(results, default_flow_style=False))
print(yaml.dump(results, default_flow_style=False)) return
return return super().handle(*args, **options)
else:
print(Control('dispatcher').cancel(cancel_data))
return
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
self.configure_dispatcher_logging()
# Close the connection, because the pg_notify broker will create new async connection
connection.close()
django_cache.close()
dispatcher_setup(get_dispatcherd_config(for_service=True))
run_service()
else:
consumer = None
try:
DispatcherMetricsServer().start()
except redis.exceptions.ConnectionError as exc:
raise CommandError(f'Dispatcher could not connect to redis, error: {exc}')
try:
queues = ['tower_broadcast_all', 'tower_settings_change', get_task_queuename()]
consumer = AWXConsumerPG('dispatcher', TaskWorker(), queues, AutoscalePool(min_workers=4), schedule=settings.CELERYBEAT_SCHEDULE)
consumer.run()
except KeyboardInterrupt:
logger.debug('Terminating Task Dispatcher')
if consumer:
consumer.stop()
def configure_dispatcher_logging(self):
# Apply special log rule for the parent process
special_logging = copy.deepcopy(settings.LOGGING)
for handler_name, handler_config in special_logging.get('handlers', {}).items():
filters = handler_config.get('filters', [])
if 'dynamic_level_filter' in filters:
handler_config['filters'] = [flt for flt in filters if flt != 'dynamic_level_filter']
logger.info(f'Dispatcherd main process replaced log level filter for {handler_name} handler')
# Apply the custom logging level here, before the asyncio code starts
special_logging.setdefault('loggers', {}).setdefault('dispatcherd', {})
special_logging['loggers']['dispatcherd']['level'] = settings.LOG_AGGREGATOR_LEVEL
logging.config.dictConfig(special_logging)

View File

@@ -5,7 +5,7 @@ from django.core.management.base import BaseCommand
from django.conf import settings from django.conf import settings
from django.core.cache import cache from django.core.cache import cache
from awx.main.dispatch import pg_bus_conn from awx.main.dispatch import pg_bus_conn
from awx.main.dispatch.worker.task import TaskWorker from awx.main.dispatch.worker.task import run_callable
from awx.main.utils.external_logging import reconfigure_rsyslog from awx.main.utils.external_logging import reconfigure_rsyslog
logger = logging.getLogger('awx.main.rsyslog_configurer') logger = logging.getLogger('awx.main.rsyslog_configurer')
@@ -26,7 +26,7 @@ class Command(BaseCommand):
conn.listen("rsyslog_configurer") conn.listen("rsyslog_configurer")
# reconfigure rsyslog on start up # reconfigure rsyslog on start up
reconfigure_rsyslog() reconfigure_rsyslog()
for e in conn.events(yield_timeouts=True): for e in conn.events():
if e is not None: if e is not None:
logger.info("Change in logging settings found. Restarting rsyslogd") logger.info("Change in logging settings found. Restarting rsyslogd")
# clear the cache of relevant settings then restart # clear the cache of relevant settings then restart
@@ -34,7 +34,7 @@ class Command(BaseCommand):
cache.delete_many(setting_keys) cache.delete_many(setting_keys)
settings._awx_conf_memoizedcache.clear() settings._awx_conf_memoizedcache.clear()
body = json.loads(e.payload) body = json.loads(e.payload)
TaskWorker.run_callable(body) run_callable(body)
except Exception: except Exception:
# Log unanticipated exception in addition to writing to stderr to get timestamps and other metadata # Log unanticipated exception in addition to writing to stderr to get timestamps and other metadata
logger.exception('Encountered unhandled error in rsyslog_configurer main loop') logger.exception('Encountered unhandled error in rsyslog_configurer main loop')

View File

@@ -21,6 +21,6 @@ class Migration(migrations.Migration):
] ]
operations = [ operations = [
migrations.RunPython(setup_tower_managed_defaults), migrations.RunPython(setup_tower_managed_defaults, migrations.RunPython.noop),
migrations.RunPython(setup_rbac_role_system_administrator), migrations.RunPython(setup_rbac_role_system_administrator, migrations.RunPython.noop),
] ]

View File

@@ -98,5 +98,5 @@ class Migration(migrations.Migration):
] ]
operations = [ operations = [
migrations.RunPython(convert_controller_role_definitions), migrations.RunPython(convert_controller_role_definitions, migrations.RunPython.noop),
] ]

View File

@@ -3,19 +3,15 @@ from django.db import migrations, models
from awx.main.migrations._create_system_jobs import delete_clear_tokens_sjt from awx.main.migrations._create_system_jobs import delete_clear_tokens_sjt
# --- START of function merged from 0203_rename_github_app_kind.py ---
def update_github_app_kind(apps, schema_editor): def update_github_app_kind(apps, schema_editor):
""" """
Updates the 'kind' field for CredentialType records Updates the 'namespace' field for CredentialType records
from 'github_app' to 'github_app_lookup'. from 'github_app' to 'github_app_lookup'.
This addresses a change in the entry point key for the GitHub App plugin. This addresses a change in the entry point key for the GitHub App plugin.
""" """
CredentialType = apps.get_model('main', 'CredentialType') CredentialType = apps.get_model('main', 'CredentialType')
db_alias = schema_editor.connection.alias db_alias = schema_editor.connection.alias
CredentialType.objects.using(db_alias).filter(kind='github_app').update(kind='github_app_lookup') CredentialType.objects.using(db_alias).filter(namespace='github_app').update(namespace='github_app_lookup')
# --- END of function merged from 0203_rename_github_app_kind.py ---
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -118,7 +114,5 @@ class Migration(migrations.Migration):
max_length=32, max_length=32,
), ),
), ),
# --- START of operations merged from 0203_rename_github_app_kind.py ---
migrations.RunPython(update_github_app_kind, migrations.RunPython.noop), migrations.RunPython(update_github_app_kind, migrations.RunPython.noop),
# --- END of operations merged from 0203_rename_github_app_kind.py ---
] ]

View File

@@ -386,7 +386,6 @@ class gce(PluginFileInjector):
# auth related items # auth related items
ret['auth_kind'] = "serviceaccount" ret['auth_kind'] = "serviceaccount"
filters = []
# TODO: implement gce group_by options # TODO: implement gce group_by options
# gce never processed the group_by field, if it had, we would selectively # gce never processed the group_by field, if it had, we would selectively
# apply those options here, but it did not, so all groups are added here # apply those options here, but it did not, so all groups are added here
@@ -420,8 +419,6 @@ class gce(PluginFileInjector):
if keyed_groups: if keyed_groups:
ret['keyed_groups'] = keyed_groups ret['keyed_groups'] = keyed_groups
if filters:
ret['filters'] = filters
if compose_dict: if compose_dict:
ret['compose'] = compose_dict ret['compose'] = compose_dict
if inventory_source.source_regions and 'all' not in inventory_source.source_regions: if inventory_source.source_regions and 'all' not in inventory_source.source_regions:

View File

@@ -315,12 +315,11 @@ class PrimordialModel(HasEditsMixin, CreatedModifiedModel):
) )
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
r = super(PrimordialModel, self).__init__(*args, **kwargs) super(PrimordialModel, self).__init__(*args, **kwargs)
if self.pk: if self.pk:
self._prior_values_store = self._get_fields_snapshot() self._prior_values_store = self._get_fields_snapshot()
else: else:
self._prior_values_store = {} self._prior_values_store = {}
return r
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
update_fields = kwargs.get('update_fields', []) update_fields = kwargs.get('update_fields', [])

View File

@@ -50,9 +50,8 @@ class HasPolicyEditsMixin(HasEditsMixin):
abstract = True abstract = True
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
r = super(BaseModel, self).__init__(*args, **kwargs) super(BaseModel, self).__init__(*args, **kwargs)
self._prior_values_store = self._get_fields_snapshot() self._prior_values_store = self._get_fields_snapshot()
return r
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
super(BaseModel, self).save(*args, **kwargs) super(BaseModel, self).save(*args, **kwargs)

View File

@@ -188,6 +188,16 @@ class SurveyJobTemplateMixin(models.Model):
runtime_extra_vars.pop(variable_key) runtime_extra_vars.pop(variable_key)
if default is not None: if default is not None:
# do not add variables that contain an empty string, are not required and are not present in extra_vars
# password fields must be skipped, because default values have special behaviour
if (
default == ''
and not survey_element.get('required')
and survey_element.get('type') != 'password'
and variable_key not in runtime_extra_vars
):
continue
decrypted_default = default decrypted_default = default
if survey_element['type'] == "password" and isinstance(decrypted_default, str) and decrypted_default.startswith('$encrypted$'): if survey_element['type'] == "password" and isinstance(decrypted_default, str) and decrypted_default.startswith('$encrypted$'):
decrypted_default = decrypt_value(get_encryption_key('value', pk=None), decrypted_default) decrypted_default = decrypt_value(get_encryption_key('value', pk=None), decrypted_default)

View File

@@ -10,11 +10,13 @@ import json
import logging import logging
import os import os
import re import re
import socket
import subprocess import subprocess
import tempfile import tempfile
from collections import OrderedDict from collections import OrderedDict
# Dispatcher
from dispatcherd.factories import get_control_from_settings
# Django # Django
from django.conf import settings from django.conf import settings
from django.db import models, connection, transaction from django.db import models, connection, transaction
@@ -24,7 +26,6 @@ from django.utils.translation import gettext_lazy as _
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.encoding import smart_str from django.utils.encoding import smart_str
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from flags.state import flag_enabled
# REST Framework # REST Framework
from rest_framework.exceptions import ParseError from rest_framework.exceptions import ParseError
@@ -39,7 +40,6 @@ from ansible_base.rbac.models import RoleEvaluation
# AWX # AWX
from awx.main.models.base import CommonModelNameNotUnique, PasswordFieldsModel, NotificationFieldsModel from awx.main.models.base import CommonModelNameNotUnique, PasswordFieldsModel, NotificationFieldsModel
from awx.main.dispatch import get_task_queuename from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.control import Control as ControlDispatcher
from awx.main.registrar import activity_stream_registrar from awx.main.registrar import activity_stream_registrar
from awx.main.models.mixins import TaskManagerUnifiedJobMixin, ExecutionEnvironmentMixin from awx.main.models.mixins import TaskManagerUnifiedJobMixin, ExecutionEnvironmentMixin
from awx.main.models.rbac import to_permissions from awx.main.models.rbac import to_permissions
@@ -918,7 +918,7 @@ class UnifiedJob(
# If we have a start and finished time, and haven't already calculated # If we have a start and finished time, and haven't already calculated
# out the time that elapsed, do so. # out the time that elapsed, do so.
if self.started and self.finished and self.elapsed == 0.0: if self.started and self.finished and self.elapsed == decimal.Decimal(0):
td = self.finished - self.started td = self.finished - self.started
elapsed = decimal.Decimal(td.total_seconds()) elapsed = decimal.Decimal(td.total_seconds())
self.elapsed = elapsed.quantize(dq) self.elapsed = elapsed.quantize(dq)
@@ -1354,8 +1354,6 @@ class UnifiedJob(
status_data['instance_group_name'] = None status_data['instance_group_name'] = None
elif status in ['successful', 'failed', 'canceled'] and self.finished: elif status in ['successful', 'failed', 'canceled'] and self.finished:
status_data['finished'] = datetime.datetime.strftime(self.finished, "%Y-%m-%dT%H:%M:%S.%fZ") status_data['finished'] = datetime.datetime.strftime(self.finished, "%Y-%m-%dT%H:%M:%S.%fZ")
elif status == 'running':
status_data['started'] = datetime.datetime.strftime(self.finished, "%Y-%m-%dT%H:%M:%S.%fZ")
status_data.update(self.websocket_emit_data()) status_data.update(self.websocket_emit_data())
status_data['group_name'] = 'jobs' status_data['group_name'] = 'jobs'
if getattr(self, 'unified_job_template_id', None): if getattr(self, 'unified_job_template_id', None):
@@ -1487,53 +1485,17 @@ class UnifiedJob(
return 'Previous Task Canceled: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (self.model_to_str(), self.name, self.id) return 'Previous Task Canceled: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (self.model_to_str(), self.name, self.id)
return None return None
def fallback_cancel(self):
if not self.celery_task_id:
self.refresh_from_db(fields=['celery_task_id'])
self.cancel_dispatcher_process()
def cancel_dispatcher_process(self): def cancel_dispatcher_process(self):
"""Returns True if dispatcher running this job acknowledged request and sent SIGTERM""" """Returns True if dispatcher running this job acknowledged request and sent SIGTERM"""
if not self.celery_task_id: if not self.celery_task_id:
return False return False
canceled = []
# Special case for task manager (used during workflow job cancellation)
if not connection.get_autocommit():
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
try:
from dispatcherd.factories import get_control_from_settings
ctl = get_control_from_settings()
ctl.control('cancel', data={'uuid': self.celery_task_id})
except Exception:
logger.exception("Error sending cancel command to new dispatcher")
else:
try:
ControlDispatcher('dispatcher', self.controller_node).cancel([self.celery_task_id], with_reply=False)
except Exception:
logger.exception("Error sending cancel command to legacy dispatcher")
return True # task manager itself needs to act under assumption that cancel was received
# Standard case with reply
try: try:
timeout = 5 logger.info(f'Sending cancel message to pg_notify channel {self.controller_node} for task {self.celery_task_id}')
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): ctl = get_control_from_settings(default_publish_channel=self.controller_node)
from dispatcherd.factories import get_control_from_settings ctl.control('cancel', data={'uuid': self.celery_task_id})
ctl = get_control_from_settings()
results = ctl.control_with_reply('cancel', data={'uuid': self.celery_task_id}, expected_replies=1, timeout=timeout)
# Check if cancel was successful by checking if we got any results
return bool(results and len(results) > 0)
else:
# Original implementation
canceled = ControlDispatcher('dispatcher', self.controller_node).cancel([self.celery_task_id])
except socket.timeout:
logger.error(f'could not reach dispatcher on {self.controller_node} within {timeout}s')
except Exception: except Exception:
logger.exception("error encountered when checking task status") logger.exception("Error sending cancel command to dispatcher")
return bool(self.celery_task_id in canceled) # True or False, whether confirmation was obtained
def cancel(self, job_explanation=None, is_chain=False): def cancel(self, job_explanation=None, is_chain=False):
if self.can_cancel: if self.can_cancel:
@@ -1556,19 +1518,13 @@ class UnifiedJob(
# the job control process will use the cancel_flag to distinguish a shutdown from a cancel # the job control process will use the cancel_flag to distinguish a shutdown from a cancel
self.save(update_fields=cancel_fields) self.save(update_fields=cancel_fields)
controller_notified = False # Be extra sure we have the task id, in case job is transitioning into running right now
if self.celery_task_id: if not self.celery_task_id:
controller_notified = self.cancel_dispatcher_process() self.refresh_from_db(fields=['celery_task_id', 'controller_node'])
# If a SIGTERM signal was sent to the control process, and acked by the dispatcher # send pg_notify message to cancel, will not send until transaction completes
# then we want to let its own cleanup change status, otherwise change status now if self.celery_task_id:
if not controller_notified: self.cancel_dispatcher_process()
if self.status != 'canceled':
self.status = 'canceled'
self.save(update_fields=['status'])
# Avoid race condition where we have stale model from pending state but job has already started,
# its checking signal but not cancel_flag, so re-send signal after updating cancel fields
self.fallback_cancel()
return self.cancel_flag return self.cancel_flag

View File

@@ -785,7 +785,7 @@ class WorkflowJob(UnifiedJob, WorkflowJobOptions, SurveyJobMixin, JobNotificatio
def cancel_dispatcher_process(self): def cancel_dispatcher_process(self):
# WorkflowJobs don't _actually_ run anything in the dispatcher, so # WorkflowJobs don't _actually_ run anything in the dispatcher, so
# there's no point in asking the dispatcher if it knows about this task # there's no point in asking the dispatcher if it knows about this task
return True return
class WorkflowApprovalTemplate(UnifiedJobTemplate, RelatedJobsMixin): class WorkflowApprovalTemplate(UnifiedJobTemplate, RelatedJobsMixin):

View File

@@ -76,10 +76,12 @@ class GrafanaBackend(AWXBaseEmailBackend, CustomNotificationBase):
grafana_headers = {} grafana_headers = {}
if 'started' in m.body: if 'started' in m.body:
try: try:
epoch = datetime.datetime.utcfromtimestamp(0) epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc)
grafana_data['time'] = grafana_data['timeEnd'] = int((dp.parse(m.body['started']).replace(tzinfo=None) - epoch).total_seconds() * 1000) grafana_data['time'] = grafana_data['timeEnd'] = int(
(dp.parse(m.body['started']).replace(tzinfo=datetime.timezone.utc) - epoch).total_seconds() * 1000
)
if m.body.get('finished'): if m.body.get('finished'):
grafana_data['timeEnd'] = int((dp.parse(m.body['finished']).replace(tzinfo=None) - epoch).total_seconds() * 1000) grafana_data['timeEnd'] = int((dp.parse(m.body['finished']).replace(tzinfo=datetime.timezone.utc) - epoch).total_seconds() * 1000)
except ValueError: except ValueError:
logger.error(smart_str(_("Error converting time {} or timeEnd {} to int.").format(m.body['started'], m.body['finished']))) logger.error(smart_str(_("Error converting time {} or timeEnd {} to int.").format(m.body['started'], m.body['finished'])))
if not self.fail_silently: if not self.fail_silently:

View File

@@ -1,6 +1,7 @@
# Copyright (c) 2016 Ansible, Inc. # Copyright (c) 2016 Ansible, Inc.
# All Rights Reserved. # All Rights Reserved.
import base64
import json import json
import logging import logging
import requests import requests
@@ -84,20 +85,25 @@ class WebhookBackend(AWXBaseEmailBackend, CustomNotificationBase):
if resp.status_code not in [301, 307]: if resp.status_code not in [301, 307]:
break break
# convert the url to a base64 encoded string for safe logging
url_log_safe = base64.b64encode(url.encode('UTF-8'))
# get the next URL to try
url_next = resp.headers.get("Location", None)
url_next_log_safe = base64.b64encode(url_next.encode('UTF-8')) if url_next else b'None'
# we've hit a redirect. extract the redirect URL out of the first response header and try again # we've hit a redirect. extract the redirect URL out of the first response header and try again
logger.warning( logger.warning(f"Received a {resp.status_code} from {url_log_safe}, trying to reach redirect url {url_next_log_safe}; attempt #{retries+1}")
f"Received a {resp.status_code} from {url}, trying to reach redirect url {resp.headers.get('Location', None)}; attempt #{retries+1}"
)
# take the first redirect URL in the response header and try that # take the first redirect URL in the response header and try that
url = resp.headers.get("Location", None) url = url_next
if url is None: if url is None:
err = f"Webhook notification received redirect to a blank URL from {url}. Response headers={resp.headers}" err = f"Webhook notification received redirect to a blank URL from {url_log_safe}. Response headers={resp.headers}"
break break
else: else:
# no break condition in the loop encountered; therefore we have hit the maximum number of retries # no break condition in the loop encountered; therefore we have hit the maximum number of retries
err = f"Webhook notification max number of retries [{self.MAX_RETRIES}] exceeded. Failed to send webhook notification to {url}" err = f"Webhook notification max number of retries [{self.MAX_RETRIES}] exceeded. Failed to send webhook notification to {url_log_safe}"
if resp.status_code >= 400: if resp.status_code >= 400:
err = f"Error sending webhook notification: {resp.status_code}" err = f"Error sending webhook notification: {resp.status_code}"

View File

@@ -19,9 +19,6 @@ from django.utils.timezone import now as tz_now
from django.conf import settings from django.conf import settings
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
# django-flags
from flags.state import flag_enabled
from ansible_base.lib.utils.models import get_type_for_model from ansible_base.lib.utils.models import get_type_for_model
# django-ansible-base # django-ansible-base
@@ -523,19 +520,7 @@ class TaskManager(TaskBase):
task.save() task.save()
task.log_lifecycle("waiting") task.log_lifecycle("waiting")
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): self.control_nodes_to_notify.add(task.get_queue_name())
self.control_nodes_to_notify.add(task.get_queue_name())
else:
# apply_async does a NOTIFY to the channel dispatcher is listening to
# postgres will treat this as part of the transaction, which is what we want
if task.status != 'failed' and type(task) is not WorkflowJob:
task_cls = task._get_task_class()
task_cls.apply_async(
[task.pk],
opts,
queue=task.get_queue_name(),
uuid=task.celery_task_id,
)
# In exception cases, like a job failing pre-start checks, we send the websocket status message. # In exception cases, like a job failing pre-start checks, we send the websocket status message.
# For jobs going into waiting, we omit this because of performance issues, as it should go to running quickly # For jobs going into waiting, we omit this because of performance issues, as it should go to running quickly
@@ -729,7 +714,6 @@ class TaskManager(TaskBase):
for workflow_approval in self.get_expired_workflow_approvals(): for workflow_approval in self.get_expired_workflow_approvals():
self.timeout_approval_node(workflow_approval) self.timeout_approval_node(workflow_approval)
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): for controller_node in self.control_nodes_to_notify:
for controller_node in self.control_nodes_to_notify: logger.info(f'Notifying node {controller_node} of new waiting jobs.')
logger.info(f'Notifying node {controller_node} of new waiting jobs.') dispatch_waiting_jobs.apply_async(queue=controller_node)
dispatch_waiting_jobs.apply_async(queue=controller_node)

View File

@@ -4,10 +4,12 @@ import logging
# Django # Django
from django.conf import settings from django.conf import settings
# Dispatcherd
from dispatcherd.publish import task
# AWX # AWX
from awx import MODE from awx import MODE
from awx.main.scheduler import TaskManager, DependencyManager, WorkflowManager from awx.main.scheduler import TaskManager, DependencyManager, WorkflowManager
from awx.main.dispatch.publish import task as task_awx
from awx.main.dispatch import get_task_queuename from awx.main.dispatch import get_task_queuename
logger = logging.getLogger('awx.main.scheduler') logger = logging.getLogger('awx.main.scheduler')
@@ -20,16 +22,16 @@ def run_manager(manager, prefix):
manager().schedule() manager().schedule()
@task_awx(queue=get_task_queuename) @task(queue=get_task_queuename)
def task_manager(): def task_manager():
run_manager(TaskManager, "task") run_manager(TaskManager, "task")
@task_awx(queue=get_task_queuename) @task(queue=get_task_queuename)
def dependency_manager(): def dependency_manager():
run_manager(DependencyManager, "dependency") run_manager(DependencyManager, "dependency")
@task_awx(queue=get_task_queuename) @task(queue=get_task_queuename)
def workflow_manager(): def workflow_manager():
run_manager(WorkflowManager, "workflow") run_manager(WorkflowManager, "workflow")

View File

@@ -12,7 +12,7 @@ from django.db import transaction
# Django flags # Django flags
from flags.state import flag_enabled from flags.state import flag_enabled
from awx.main.dispatch.publish import task from dispatcherd.publish import task
from awx.main.dispatch import get_task_queuename from awx.main.dispatch import get_task_queuename
from awx.main.models.indirect_managed_node_audit import IndirectManagedNodeAudit from awx.main.models.indirect_managed_node_audit import IndirectManagedNodeAudit
from awx.main.models.event_query import EventQuery from awx.main.models.event_query import EventQuery

View File

@@ -6,8 +6,8 @@ from django.conf import settings
from django.db.models import Count, F from django.db.models import Count, F
from django.db.models.functions import TruncMonth from django.db.models.functions import TruncMonth
from django.utils.timezone import now from django.utils.timezone import now
from dispatcherd.publish import task
from awx.main.dispatch import get_task_queuename from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.publish import task as task_awx
from awx.main.models.inventory import HostMetric, HostMetricSummaryMonthly from awx.main.models.inventory import HostMetric, HostMetricSummaryMonthly
from awx.main.tasks.helpers import is_run_threshold_reached from awx.main.tasks.helpers import is_run_threshold_reached
from awx.conf.license import get_license from awx.conf.license import get_license
@@ -17,7 +17,7 @@ from awx.main.utils.db import bulk_update_sorted_by_id
logger = logging.getLogger('awx.main.tasks.host_metrics') logger = logging.getLogger('awx.main.tasks.host_metrics')
@task_awx(queue=get_task_queuename) @task(queue=get_task_queuename)
def cleanup_host_metrics(): def cleanup_host_metrics():
if is_run_threshold_reached(getattr(settings, 'CLEANUP_HOST_METRICS_LAST_TS', None), getattr(settings, 'CLEANUP_HOST_METRICS_INTERVAL', 30) * 86400): if is_run_threshold_reached(getattr(settings, 'CLEANUP_HOST_METRICS_LAST_TS', None), getattr(settings, 'CLEANUP_HOST_METRICS_INTERVAL', 30) * 86400):
logger.info(f"Executing cleanup_host_metrics, last ran at {getattr(settings, 'CLEANUP_HOST_METRICS_LAST_TS', '---')}") logger.info(f"Executing cleanup_host_metrics, last ran at {getattr(settings, 'CLEANUP_HOST_METRICS_LAST_TS', '---')}")
@@ -28,7 +28,7 @@ def cleanup_host_metrics():
logger.info("Finished cleanup_host_metrics") logger.info("Finished cleanup_host_metrics")
@task_awx(queue=get_task_queuename) @task(queue=get_task_queuename)
def host_metric_summary_monthly(): def host_metric_summary_monthly():
"""Run cleanup host metrics summary monthly task each week""" """Run cleanup host metrics summary monthly task each week"""
if is_run_threshold_reached(getattr(settings, 'HOST_METRIC_SUMMARY_TASK_LAST_TS', None), getattr(settings, 'HOST_METRIC_SUMMARY_TASK_INTERVAL', 7) * 86400): if is_run_threshold_reached(getattr(settings, 'HOST_METRIC_SUMMARY_TASK_LAST_TS', None), getattr(settings, 'HOST_METRIC_SUMMARY_TASK_INTERVAL', 7) * 86400):

View File

@@ -36,7 +36,6 @@ from dispatcherd.publish import task
from dispatcherd.utils import serialize_task from dispatcherd.utils import serialize_task
# AWX # AWX
from awx.main.dispatch.publish import task as task_awx
from awx.main.dispatch import get_task_queuename from awx.main.dispatch import get_task_queuename
from awx.main.constants import ( from awx.main.constants import (
PRIVILEGE_ESCALATION_METHODS, PRIVILEGE_ESCALATION_METHODS,
@@ -85,6 +84,7 @@ from awx.main.utils.common import (
create_partition, create_partition,
ScheduleWorkflowManager, ScheduleWorkflowManager,
ScheduleTaskManager, ScheduleTaskManager,
getattr_dne,
) )
from awx.conf.license import get_license from awx.conf.license import get_license
from awx.main.utils.handlers import SpecialInventoryHandler from awx.main.utils.handlers import SpecialInventoryHandler
@@ -93,9 +93,76 @@ from awx.main.utils.update_model import update_model
# Django flags # Django flags
from flags.state import flag_enabled from flags.state import flag_enabled
# Workload Identity
from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope
logger = logging.getLogger('awx.main.tasks.jobs') logger = logging.getLogger('awx.main.tasks.jobs')
def populate_claims_for_workload(unified_job) -> dict:
"""
Extract JWT claims from a Controller workload for the aap_controller_automation_job scope.
"""
# Related objects in the UnifiedJob model, applies to all job types
organization = getattr_dne(unified_job, 'organization')
ujt = getattr_dne(unified_job, 'unified_job_template')
instance_group = getattr_dne(unified_job, 'instance_group')
claims = {
AutomationControllerJobScope.CLAIM_JOB_ID: unified_job.id,
AutomationControllerJobScope.CLAIM_JOB_NAME: unified_job.name,
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: unified_job.launch_type,
}
# Related objects in the UnifiedJob model, applies to all job types
# null cases are omitted because of OIDC
if organization := getattr_dne(unified_job, 'organization'):
claims[AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME] = organization.name
claims[AutomationControllerJobScope.CLAIM_ORGANIZATION_ID] = organization.id
if ujt := getattr_dne(unified_job, 'unified_job_template'):
claims[AutomationControllerJobScope.CLAIM_UNIFIED_JOB_TEMPLATE_NAME] = ujt.name
claims[AutomationControllerJobScope.CLAIM_UNIFIED_JOB_TEMPLATE_ID] = ujt.id
if instance_group := getattr_dne(unified_job, 'instance_group'):
claims[AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_NAME] = instance_group.name
claims[AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_ID] = instance_group.id
# Related objects on concrete models, may not be valid for type of unified_job
if inventory := getattr_dne(unified_job, 'inventory', None):
claims[AutomationControllerJobScope.CLAIM_INVENTORY_NAME] = inventory.name
claims[AutomationControllerJobScope.CLAIM_INVENTORY_ID] = inventory.id
if execution_environment := getattr_dne(unified_job, 'execution_environment', None):
claims[AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_NAME] = execution_environment.name
claims[AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_ID] = execution_environment.id
if project := getattr_dne(unified_job, 'project', None):
claims[AutomationControllerJobScope.CLAIM_PROJECT_NAME] = project.name
claims[AutomationControllerJobScope.CLAIM_PROJECT_ID] = project.id
if jt := getattr_dne(unified_job, 'job_template', None):
claims[AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_NAME] = jt.name
claims[AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_ID] = jt.id
# Only valid for job templates
if hasattr(unified_job, 'playbook'):
claims[AutomationControllerJobScope.CLAIM_PLAYBOOK_NAME] = unified_job.playbook
# Not valid for inventory updates and system jobs
if hasattr(unified_job, 'job_type'):
claims[AutomationControllerJobScope.CLAIM_JOB_TYPE] = unified_job.job_type
launched_by: dict = unified_job.launched_by
if 'name' in launched_by:
claims[AutomationControllerJobScope.CLAIM_LAUNCHED_BY_NAME] = launched_by['name']
if 'id' in launched_by:
claims[AutomationControllerJobScope.CLAIM_LAUNCHED_BY_ID] = launched_by['id']
return claims
def with_path_cleanup(f): def with_path_cleanup(f):
@functools.wraps(f) @functools.wraps(f)
def _wrapped(self, *args, **kwargs): def _wrapped(self, *args, **kwargs):
@@ -851,7 +918,7 @@ class SourceControlMixin(BaseTask):
self.release_lock(project) self.release_lock(project)
@task_awx(queue=get_task_queuename) @task(queue=get_task_queuename)
class RunJob(SourceControlMixin, BaseTask): class RunJob(SourceControlMixin, BaseTask):
""" """
Run a job using ansible-playbook. Run a job using ansible-playbook.
@@ -1174,7 +1241,7 @@ class RunJob(SourceControlMixin, BaseTask):
update_inventory_computed_fields.delay(inventory.id) update_inventory_computed_fields.delay(inventory.id)
@task_awx(queue=get_task_queuename) @task(queue=get_task_queuename)
class RunProjectUpdate(BaseTask): class RunProjectUpdate(BaseTask):
model = ProjectUpdate model = ProjectUpdate
event_model = ProjectUpdateEvent event_model = ProjectUpdateEvent
@@ -1329,7 +1396,6 @@ class RunProjectUpdate(BaseTask):
'local_path': os.path.basename(project_update.project.local_path), 'local_path': os.path.basename(project_update.project.local_path),
'project_path': project_update.get_project_path(check_if_exists=False), # deprecated 'project_path': project_update.get_project_path(check_if_exists=False), # deprecated
'insights_url': settings.INSIGHTS_URL_BASE, 'insights_url': settings.INSIGHTS_URL_BASE,
'oidc_endpoint': settings.INSIGHTS_OIDC_ENDPOINT,
'awx_license_type': get_license().get('license_type', 'UNLICENSED'), 'awx_license_type': get_license().get('license_type', 'UNLICENSED'),
'awx_version': get_awx_version(), 'awx_version': get_awx_version(),
'scm_url': scm_url, 'scm_url': scm_url,
@@ -1513,7 +1579,7 @@ class RunProjectUpdate(BaseTask):
return [] return []
@task_awx(queue=get_task_queuename) @task(queue=get_task_queuename)
class RunInventoryUpdate(SourceControlMixin, BaseTask): class RunInventoryUpdate(SourceControlMixin, BaseTask):
model = InventoryUpdate model = InventoryUpdate
event_model = InventoryUpdateEvent event_model = InventoryUpdateEvent
@@ -1776,7 +1842,7 @@ class RunInventoryUpdate(SourceControlMixin, BaseTask):
raise PostRunError('Error occured while saving inventory data, see traceback or server logs', status='error', tb=traceback.format_exc()) raise PostRunError('Error occured while saving inventory data, see traceback or server logs', status='error', tb=traceback.format_exc())
@task_awx(queue=get_task_queuename) @task(queue=get_task_queuename)
class RunAdHocCommand(BaseTask): class RunAdHocCommand(BaseTask):
""" """
Run an ad hoc command using ansible. Run an ad hoc command using ansible.
@@ -1929,7 +1995,7 @@ class RunAdHocCommand(BaseTask):
return d return d
@task_awx(queue=get_task_queuename) @task(queue=get_task_queuename)
class RunSystemJob(BaseTask): class RunSystemJob(BaseTask):
model = SystemJob model = SystemJob
event_model = SystemJobEvent event_model = SystemJobEvent

View File

@@ -20,6 +20,9 @@ import ansible_runner
# django-ansible-base # django-ansible-base
from ansible_base.lib.utils.db import advisory_lock from ansible_base.lib.utils.db import advisory_lock
# Dispatcherd
from dispatcherd.publish import task
# AWX # AWX
from awx.main.utils.execution_environments import get_default_pod_spec from awx.main.utils.execution_environments import get_default_pod_spec
from awx.main.exceptions import ReceptorNodeNotFound from awx.main.exceptions import ReceptorNodeNotFound
@@ -32,7 +35,6 @@ from awx.main.constants import MAX_ISOLATED_PATH_COLON_DELIMITER
from awx.main.tasks.signals import signal_state, signal_callback, SignalExit from awx.main.tasks.signals import signal_state, signal_callback, SignalExit
from awx.main.models import Instance, InstanceLink, UnifiedJob, ReceptorAddress from awx.main.models import Instance, InstanceLink, UnifiedJob, ReceptorAddress
from awx.main.dispatch import get_task_queuename from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.publish import task as task_awx
# Receptorctl # Receptorctl
from receptorctl.socket_interface import ReceptorControl from receptorctl.socket_interface import ReceptorControl
@@ -852,7 +854,7 @@ def reload_receptor():
raise RuntimeError("Receptor reload failed") raise RuntimeError("Receptor reload failed")
@task_awx(on_duplicate='queue_one') @task(on_duplicate='queue_one')
def write_receptor_config(): def write_receptor_config():
""" """
This task runs async on each control node, K8S only. This task runs async on each control node, K8S only.
@@ -875,7 +877,7 @@ def write_receptor_config():
reload_receptor() reload_receptor()
@task_awx(queue=get_task_queuename, on_duplicate='discard') @task(queue=get_task_queuename, on_duplicate='discard')
def remove_deprovisioned_node(hostname): def remove_deprovisioned_node(hostname):
InstanceLink.objects.filter(source__hostname=hostname).update(link_state=InstanceLink.States.REMOVING) InstanceLink.objects.filter(source__hostname=hostname).update(link_state=InstanceLink.States.REMOVING)
InstanceLink.objects.filter(target__instance__hostname=hostname).update(link_state=InstanceLink.States.REMOVING) InstanceLink.objects.filter(target__instance__hostname=hostname).update(link_state=InstanceLink.States.REMOVING)

View File

@@ -69,7 +69,7 @@ def signal_callback():
def with_signal_handling(f): def with_signal_handling(f):
""" """
Change signal handling to make signal_callback return True in event of SIGTERM or SIGINT. Change signal handling to make signal_callback return True in event of SIGTERM, SIGINT, or SIGUSR1.
""" """
@functools.wraps(f) @functools.wraps(f)

View File

@@ -9,12 +9,12 @@ import shutil
import time import time
from collections import namedtuple from collections import namedtuple
from contextlib import redirect_stdout from contextlib import redirect_stdout
from datetime import datetime
from packaging.version import Version from packaging.version import Version
from io import StringIO from io import StringIO
# dispatcherd # dispatcherd
from dispatcherd.factories import get_control_from_settings from dispatcherd.factories import get_control_from_settings
from dispatcherd.publish import task
# Runner # Runner
import ansible_runner.cleanup import ansible_runner.cleanup
@@ -56,7 +56,6 @@ from awx.main.analytics.subsystem_metrics import DispatcherMetrics
from awx.main.constants import ACTIVE_STATES, ERROR_STATES from awx.main.constants import ACTIVE_STATES, ERROR_STATES
from awx.main.consumers import emit_channel_notification from awx.main.consumers import emit_channel_notification
from awx.main.dispatch import get_task_queuename, reaper from awx.main.dispatch import get_task_queuename, reaper
from awx.main.dispatch.publish import task as task_awx
from awx.main.models import ( from awx.main.models import (
Instance, Instance,
InstanceGroup, InstanceGroup,
@@ -74,7 +73,6 @@ from awx.main.tasks.host_indirect import save_indirect_host_entries
from awx.main.tasks.receptor import administrative_workunit_reaper, get_receptor_ctl, worker_cleanup, worker_info, write_receptor_config from awx.main.tasks.receptor import administrative_workunit_reaper, get_receptor_ctl, worker_cleanup, worker_info, write_receptor_config
from awx.main.utils.common import ignore_inventory_computed_fields, ignore_inventory_group_removal from awx.main.utils.common import ignore_inventory_computed_fields, ignore_inventory_group_removal
from awx.main.utils.reload import stop_local_services from awx.main.utils.reload import stop_local_services
from dispatcherd.publish import task
logger = logging.getLogger('awx.main.tasks.system') logger = logging.getLogger('awx.main.tasks.system')
@@ -95,7 +93,10 @@ def _run_dispatch_startup_common():
# TODO: Enable this on VM installs # TODO: Enable this on VM installs
if settings.IS_K8S: if settings.IS_K8S:
write_receptor_config() try:
write_receptor_config()
except Exception:
logger.exception("Failed to write receptor config, skipping.")
try: try:
convert_jsonfields() convert_jsonfields()
@@ -125,20 +126,12 @@ def _run_dispatch_startup_common():
# no-op. # no-op.
# #
apply_cluster_membership_policies() apply_cluster_membership_policies()
cluster_node_heartbeat() cluster_node_heartbeat(None)
reaper.startup_reaping() reaper.startup_reaping()
m = DispatcherMetrics() m = DispatcherMetrics()
m.reset_values() m.reset_values()
def _legacy_dispatch_startup():
"""
Legacy branch for startup: simply performs reaping of waiting jobs with a zero grace period.
"""
logger.debug("Legacy dispatcher: calling reaper.reap_waiting with grace_period=0")
reaper.reap_waiting(grace_period=0)
def _dispatcherd_dispatch_startup(): def _dispatcherd_dispatch_startup():
""" """
New dispatcherd branch for startup: uses the control API to re-submit waiting jobs. New dispatcherd branch for startup: uses the control API to re-submit waiting jobs.
@@ -153,21 +146,16 @@ def dispatch_startup():
""" """
System initialization at startup. System initialization at startup.
First, execute the common logic. First, execute the common logic.
Then, if FEATURE_DISPATCHERD_ENABLED is enabled, re-submit waiting jobs via the control API; Then, re-submit waiting jobs via the control API.
otherwise, fall back to legacy reaping of waiting jobs.
""" """
_run_dispatch_startup_common() _run_dispatch_startup_common()
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): _dispatcherd_dispatch_startup()
_dispatcherd_dispatch_startup()
else:
_legacy_dispatch_startup()
def inform_cluster_of_shutdown(): def inform_cluster_of_shutdown():
""" """
Clean system shutdown that marks the current instance offline. Clean system shutdown that marks the current instance offline.
In legacy mode, it also reaps waiting jobs. Relies on dispatcherd's built-in cleanup.
In dispatcherd mode, it relies on dispatcherd's built-in cleanup.
""" """
try: try:
inst = Instance.objects.get(hostname=settings.CLUSTER_HOST_ID) inst = Instance.objects.get(hostname=settings.CLUSTER_HOST_ID)
@@ -176,18 +164,11 @@ def inform_cluster_of_shutdown():
logger.exception("Cluster host not found: %s", settings.CLUSTER_HOST_ID) logger.exception("Cluster host not found: %s", settings.CLUSTER_HOST_ID)
return return
if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): logger.debug("No extra reaping required for instance %s", inst.hostname)
logger.debug("Dispatcherd mode: no extra reaping required for instance %s", inst.hostname)
else:
try:
logger.debug("Legacy mode: reaping waiting jobs for instance %s", inst.hostname)
reaper.reap_waiting(inst, grace_period=0)
except Exception:
logger.exception("Failed to reap waiting jobs for %s", inst.hostname)
logger.warning("Normal shutdown processed for instance %s; instance removed from capacity pool.", inst.hostname) logger.warning("Normal shutdown processed for instance %s; instance removed from capacity pool.", inst.hostname)
@task_awx(queue=get_task_queuename, timeout=3600 * 5) @task(queue=get_task_queuename, timeout=3600 * 5)
def migrate_jsonfield(table, pkfield, columns): def migrate_jsonfield(table, pkfield, columns):
batchsize = 10000 batchsize = 10000
with advisory_lock(f'json_migration_{table}', wait=False) as acquired: with advisory_lock(f'json_migration_{table}', wait=False) as acquired:
@@ -233,7 +214,7 @@ def migrate_jsonfield(table, pkfield, columns):
logger.warning(f"Migration of {table} to jsonb is finished.") logger.warning(f"Migration of {table} to jsonb is finished.")
@task_awx(queue=get_task_queuename, timeout=3600, on_duplicate='queue_one') @task(queue=get_task_queuename, timeout=3600, on_duplicate='queue_one')
def apply_cluster_membership_policies(): def apply_cluster_membership_policies():
from awx.main.signals import disable_activity_stream from awx.main.signals import disable_activity_stream
@@ -345,7 +326,7 @@ def apply_cluster_membership_policies():
logger.debug('Cluster policy computation finished in {} seconds'.format(time.time() - started_compute)) logger.debug('Cluster policy computation finished in {} seconds'.format(time.time() - started_compute))
@task_awx(queue='tower_settings_change', timeout=600) @task(queue='tower_settings_change', timeout=600)
def clear_setting_cache(setting_keys): def clear_setting_cache(setting_keys):
# log that cache is being cleared # log that cache is being cleared
logger.info(f"clear_setting_cache of keys {setting_keys}") logger.info(f"clear_setting_cache of keys {setting_keys}")
@@ -363,7 +344,7 @@ def clear_setting_cache(setting_keys):
ctl.control('set_log_level', data={'level': settings.LOG_AGGREGATOR_LEVEL}) ctl.control('set_log_level', data={'level': settings.LOG_AGGREGATOR_LEVEL})
@task_awx(queue='tower_broadcast_all', timeout=600) @task(queue='tower_broadcast_all', timeout=600)
def delete_project_files(project_path): def delete_project_files(project_path):
# TODO: possibly implement some retry logic # TODO: possibly implement some retry logic
lock_file = project_path + '.lock' lock_file = project_path + '.lock'
@@ -381,7 +362,7 @@ def delete_project_files(project_path):
logger.exception('Could not remove lock file {}'.format(lock_file)) logger.exception('Could not remove lock file {}'.format(lock_file))
@task_awx(queue='tower_broadcast_all') @task(queue='tower_broadcast_all')
def profile_sql(threshold=1, minutes=1): def profile_sql(threshold=1, minutes=1):
if threshold <= 0: if threshold <= 0:
cache.delete('awx-profile-sql-threshold') cache.delete('awx-profile-sql-threshold')
@@ -391,7 +372,7 @@ def profile_sql(threshold=1, minutes=1):
logger.error('SQL QUERIES >={}s ENABLED FOR {} MINUTE(S)'.format(threshold, minutes)) logger.error('SQL QUERIES >={}s ENABLED FOR {} MINUTE(S)'.format(threshold, minutes))
@task_awx(queue=get_task_queuename, timeout=1800) @task(queue=get_task_queuename, timeout=1800)
def send_notifications(notification_list, job_id=None): def send_notifications(notification_list, job_id=None):
if not isinstance(notification_list, list): if not isinstance(notification_list, list):
raise TypeError("notification_list should be of type list") raise TypeError("notification_list should be of type list")
@@ -436,13 +417,13 @@ def events_processed_hook(unified_job):
save_indirect_host_entries.delay(unified_job.id) save_indirect_host_entries.delay(unified_job.id)
@task_awx(queue=get_task_queuename, timeout=3600 * 5, on_duplicate='discard') @task(queue=get_task_queuename, timeout=3600 * 5, on_duplicate='discard')
def gather_analytics(): def gather_analytics():
if is_run_threshold_reached(getattr(settings, 'AUTOMATION_ANALYTICS_LAST_GATHER', None), settings.AUTOMATION_ANALYTICS_GATHER_INTERVAL): if is_run_threshold_reached(getattr(settings, 'AUTOMATION_ANALYTICS_LAST_GATHER', None), settings.AUTOMATION_ANALYTICS_GATHER_INTERVAL):
analytics.gather() analytics.gather()
@task_awx(queue=get_task_queuename, timeout=600, on_duplicate='queue_one') @task(queue=get_task_queuename, timeout=600, on_duplicate='queue_one')
def purge_old_stdout_files(): def purge_old_stdout_files():
nowtime = time.time() nowtime = time.time()
for f in os.listdir(settings.JOBOUTPUT_ROOT): for f in os.listdir(settings.JOBOUTPUT_ROOT):
@@ -504,18 +485,18 @@ class CleanupImagesAndFiles:
cls.run_remote(this_inst, **kwargs) cls.run_remote(this_inst, **kwargs)
@task_awx(queue='tower_broadcast_all', timeout=3600) @task(queue='tower_broadcast_all', timeout=3600)
def handle_removed_image(remove_images=None): def handle_removed_image(remove_images=None):
"""Special broadcast invocation of this method to handle case of deleted EE""" """Special broadcast invocation of this method to handle case of deleted EE"""
CleanupImagesAndFiles.run(remove_images=remove_images, file_pattern='') CleanupImagesAndFiles.run(remove_images=remove_images, file_pattern='')
@task_awx(queue=get_task_queuename, timeout=3600, on_duplicate='queue_one') @task(queue=get_task_queuename, timeout=3600, on_duplicate='queue_one')
def cleanup_images_and_files(): def cleanup_images_and_files():
CleanupImagesAndFiles.run(image_prune=True) CleanupImagesAndFiles.run(image_prune=True)
@task_awx(queue=get_task_queuename, timeout=600, on_duplicate='queue_one') @task(queue=get_task_queuename, timeout=600, on_duplicate='queue_one')
def execution_node_health_check(node): def execution_node_health_check(node):
if node == '': if node == '':
logger.warning('Remote health check incorrectly called with blank string') logger.warning('Remote health check incorrectly called with blank string')
@@ -640,44 +621,13 @@ def inspect_execution_and_hop_nodes(instance_list):
execution_node_health_check.apply_async([hostname]) execution_node_health_check.apply_async([hostname])
@task_awx(queue=get_task_queuename, bind_kwargs=['dispatch_time', 'worker_tasks'])
def cluster_node_heartbeat(dispatch_time=None, worker_tasks=None):
"""
Original implementation for AWX dispatcher.
Uses worker_tasks from bind_kwargs to track running tasks.
"""
# Run common instance management logic
this_inst, instance_list, lost_instances = _heartbeat_instance_management()
if this_inst is None:
return # Early return case from instance management
# Check versions
_heartbeat_check_versions(this_inst, instance_list)
# Handle lost instances
_heartbeat_handle_lost_instances(lost_instances, this_inst)
# Run local reaper - original implementation using worker_tasks
if worker_tasks is not None:
active_task_ids = []
for task_list in worker_tasks.values():
active_task_ids.extend(task_list)
# Convert dispatch_time to datetime
ref_time = datetime.fromisoformat(dispatch_time) if dispatch_time else now()
reaper.reap(instance=this_inst, excluded_uuids=active_task_ids, ref_time=ref_time)
if max(len(task_list) for task_list in worker_tasks.values()) <= 1:
reaper.reap_waiting(instance=this_inst, excluded_uuids=active_task_ids, ref_time=ref_time)
@task(queue=get_task_queuename, bind=True) @task(queue=get_task_queuename, bind=True)
def adispatch_cluster_node_heartbeat(binder): def cluster_node_heartbeat(binder):
""" """
Dispatcherd implementation. Dispatcherd implementation.
Uses Control API to get running tasks. Uses Control API to get running tasks.
""" """
# Run common instance management logic # Run common instance management logic
this_inst, instance_list, lost_instances = _heartbeat_instance_management() this_inst, instance_list, lost_instances = _heartbeat_instance_management()
if this_inst is None: if this_inst is None:
@@ -690,6 +640,9 @@ def adispatch_cluster_node_heartbeat(binder):
_heartbeat_handle_lost_instances(lost_instances, this_inst) _heartbeat_handle_lost_instances(lost_instances, this_inst)
# Get running tasks using dispatcherd API # Get running tasks using dispatcherd API
if binder is None:
logger.debug("Heartbeat finished in startup.")
return
active_task_ids = _get_active_task_ids_from_dispatcherd(binder) active_task_ids = _get_active_task_ids_from_dispatcherd(binder)
if active_task_ids is None: if active_task_ids is None:
logger.warning("No active task IDs retrieved from dispatcherd, skipping reaper") logger.warning("No active task IDs retrieved from dispatcherd, skipping reaper")
@@ -807,14 +760,16 @@ def _heartbeat_check_versions(this_inst, instance_list):
def _heartbeat_handle_lost_instances(lost_instances, this_inst): def _heartbeat_handle_lost_instances(lost_instances, this_inst):
"""Handle lost instances by reaping their jobs and marking them offline.""" """Handle lost instances by reaping their running jobs and marking them offline."""
for other_inst in lost_instances: for other_inst in lost_instances:
try: try:
# Any jobs marked as running will be marked as error
explanation = "Job reaped due to instance shutdown" explanation = "Job reaped due to instance shutdown"
reaper.reap(other_inst, job_explanation=explanation) reaper.reap(other_inst, job_explanation=explanation)
reaper.reap_waiting(other_inst, grace_period=0, job_explanation=explanation) # Any jobs that were waiting to be processed by this node will be handed back to task manager
UnifiedJob.objects.filter(status='waiting', controller_node=other_inst.hostname).update(status='pending', controller_node='', execution_node='')
except Exception: except Exception:
logger.exception('failed to reap jobs for {}'.format(other_inst.hostname)) logger.exception('failed to re-process jobs for lost instance {}'.format(other_inst.hostname))
try: try:
if settings.AWX_AUTO_DEPROVISION_INSTANCES and other_inst.node_type == "control": if settings.AWX_AUTO_DEPROVISION_INSTANCES and other_inst.node_type == "control":
deprovision_hostname = other_inst.hostname deprovision_hostname = other_inst.hostname
@@ -839,7 +794,7 @@ def _heartbeat_handle_lost_instances(lost_instances, this_inst):
logger.exception('No SQL state available. Error marking {} as lost'.format(other_inst.hostname)) logger.exception('No SQL state available. Error marking {} as lost'.format(other_inst.hostname))
@task_awx(queue=get_task_queuename, timeout=1800, on_duplicate='queue_one') @task(queue=get_task_queuename, timeout=1800, on_duplicate='queue_one')
def awx_receptor_workunit_reaper(): def awx_receptor_workunit_reaper():
""" """
When an AWX job is launched via receptor, files such as status, stdin, and stdout are created When an AWX job is launched via receptor, files such as status, stdin, and stdout are created
@@ -885,7 +840,7 @@ def awx_receptor_workunit_reaper():
administrative_workunit_reaper(receptor_work_list) administrative_workunit_reaper(receptor_work_list)
@task_awx(queue=get_task_queuename, timeout=1800, on_duplicate='queue_one') @task(queue=get_task_queuename, timeout=1800, on_duplicate='queue_one')
def awx_k8s_reaper(): def awx_k8s_reaper():
if not settings.RECEPTOR_RELEASE_WORK: if not settings.RECEPTOR_RELEASE_WORK:
return return
@@ -908,7 +863,7 @@ def awx_k8s_reaper():
logger.exception("Failed to delete orphaned pod {} from {}".format(job.log_format, group)) logger.exception("Failed to delete orphaned pod {} from {}".format(job.log_format, group))
@task_awx(queue=get_task_queuename, timeout=3600 * 5, on_duplicate='discard') @task(queue=get_task_queuename, timeout=3600 * 5, on_duplicate='discard')
def awx_periodic_scheduler(): def awx_periodic_scheduler():
lock_session_timeout_milliseconds = settings.TASK_MANAGER_LOCK_TIMEOUT * 1000 lock_session_timeout_milliseconds = settings.TASK_MANAGER_LOCK_TIMEOUT * 1000
with advisory_lock('awx_periodic_scheduler_lock', lock_session_timeout_milliseconds=lock_session_timeout_milliseconds, wait=False) as acquired: with advisory_lock('awx_periodic_scheduler_lock', lock_session_timeout_milliseconds=lock_session_timeout_milliseconds, wait=False) as acquired:
@@ -965,7 +920,7 @@ def awx_periodic_scheduler():
emit_channel_notification('schedules-changed', dict(id=schedule.id, group_name="schedules")) emit_channel_notification('schedules-changed', dict(id=schedule.id, group_name="schedules"))
@task_awx(queue=get_task_queuename, timeout=3600) @task(queue=get_task_queuename, timeout=3600)
def handle_failure_notifications(task_ids): def handle_failure_notifications(task_ids):
"""A task-ified version of the method that sends notifications.""" """A task-ified version of the method that sends notifications."""
found_task_ids = set() found_task_ids = set()
@@ -980,7 +935,7 @@ def handle_failure_notifications(task_ids):
logger.warning(f'Could not send notifications for {deleted_tasks} because they were not found in the database') logger.warning(f'Could not send notifications for {deleted_tasks} because they were not found in the database')
@task_awx(queue=get_task_queuename, timeout=3600 * 5) @task(queue=get_task_queuename, timeout=3600 * 5)
def update_inventory_computed_fields(inventory_id): def update_inventory_computed_fields(inventory_id):
""" """
Signal handler and wrapper around inventory.update_computed_fields to Signal handler and wrapper around inventory.update_computed_fields to
@@ -1030,7 +985,7 @@ def update_smart_memberships_for_inventory(smart_inventory):
return False return False
@task_awx(queue=get_task_queuename, timeout=3600, on_duplicate='queue_one') @task(queue=get_task_queuename, timeout=3600, on_duplicate='queue_one')
def update_host_smart_inventory_memberships(): def update_host_smart_inventory_memberships():
smart_inventories = Inventory.objects.filter(kind='smart', host_filter__isnull=False, pending_deletion=False) smart_inventories = Inventory.objects.filter(kind='smart', host_filter__isnull=False, pending_deletion=False)
changed_inventories = set([]) changed_inventories = set([])
@@ -1046,7 +1001,7 @@ def update_host_smart_inventory_memberships():
smart_inventory.update_computed_fields() smart_inventory.update_computed_fields()
@task_awx(queue=get_task_queuename, timeout=3600 * 5) @task(queue=get_task_queuename, timeout=3600 * 5)
def delete_inventory(inventory_id, user_id, retries=5): def delete_inventory(inventory_id, user_id, retries=5):
# Delete inventory as user # Delete inventory as user
if user_id is None: if user_id is None:
@@ -1108,7 +1063,7 @@ def _reconstruct_relationships(copy_mapping):
new_obj.save() new_obj.save()
@task_awx(queue=get_task_queuename, timeout=600) @task(queue=get_task_queuename, timeout=600)
def deep_copy_model_obj(model_module, model_name, obj_pk, new_obj_pk, user_pk, permission_check_func=None): def deep_copy_model_obj(model_module, model_name, obj_pk, new_obj_pk, user_pk, permission_check_func=None):
logger.debug('Deep copy {} from {} to {}.'.format(model_name, obj_pk, new_obj_pk)) logger.debug('Deep copy {} from {} to {}.'.format(model_name, obj_pk, new_obj_pk))
@@ -1163,7 +1118,7 @@ def deep_copy_model_obj(model_module, model_name, obj_pk, new_obj_pk, user_pk, p
update_inventory_computed_fields.delay(new_obj.id) update_inventory_computed_fields.delay(new_obj.id)
@task_awx(queue=get_task_queuename, timeout=3600, on_duplicate='discard') @task(queue=get_task_queuename, timeout=3600, on_duplicate='discard')
def periodic_resource_sync(): def periodic_resource_sync():
if not getattr(settings, 'RESOURCE_SERVER', None): if not getattr(settings, 'RESOURCE_SERVER', None):
logger.debug("Skipping periodic resource_sync, RESOURCE_SERVER not configured") logger.debug("Skipping periodic resource_sync, RESOURCE_SERVER not configured")

View File

@@ -6,14 +6,13 @@ from dispatcherd.publish import task
from django.db import connection from django.db import connection
from awx.main.dispatch import get_task_queuename from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.publish import task as old_task
from ansible_base.lib.utils.db import advisory_lock from ansible_base.lib.utils.db import advisory_lock
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@old_task(queue=get_task_queuename) @task(queue=get_task_queuename)
def sleep_task(seconds=10, log=False): def sleep_task(seconds=10, log=False):
if log: if log:
logger.info('starting sleep_task') logger.info('starting sleep_task')

View File

@@ -1,8 +1,11 @@
import pytest import pytest
from django.test import RequestFactory
from prometheus_client.parser import text_string_to_metric_families from prometheus_client.parser import text_string_to_metric_families
from rest_framework.request import Request
from awx.main import models from awx.main import models
from awx.main.analytics.metrics import metrics from awx.main.analytics.metrics import metrics
from awx.main.analytics.dispatcherd_metrics import get_dispatcherd_metrics
from awx.api.versioning import reverse from awx.api.versioning import reverse
EXPECTED_VALUES = { EXPECTED_VALUES = {
@@ -77,3 +80,55 @@ def test_metrics_http_methods(get, post, patch, put, options, admin):
assert patch(get_metrics_view_db_only(), user=admin).status_code == 405 assert patch(get_metrics_view_db_only(), user=admin).status_code == 405
assert post(get_metrics_view_db_only(), user=admin).status_code == 405 assert post(get_metrics_view_db_only(), user=admin).status_code == 405
assert options(get_metrics_view_db_only(), user=admin).status_code == 200 assert options(get_metrics_view_db_only(), user=admin).status_code == 200
class DummyMetricsResponse:
def __init__(self, payload):
self._payload = payload
def read(self):
return self._payload
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def test_dispatcherd_metrics_node_filter_match(mocker, settings):
settings.CLUSTER_HOST_ID = "awx-1"
payload = b'# HELP test_metric A test metric\n# TYPE test_metric gauge\ntest_metric 1\n'
def fake_urlopen(url, timeout=1.0):
return DummyMetricsResponse(payload)
mocker.patch('urllib.request.urlopen', fake_urlopen)
request = Request(RequestFactory().get('/api/v2/metrics/', {'node': 'awx-1'}))
assert get_dispatcherd_metrics(request) == payload.decode('utf-8')
def test_dispatcherd_metrics_node_filter_excludes_local(mocker, settings):
settings.CLUSTER_HOST_ID = "awx-1"
def fake_urlopen(*args, **kwargs):
raise AssertionError("urlopen should not be called when node filter excludes local node")
mocker.patch('urllib.request.urlopen', fake_urlopen)
request = Request(RequestFactory().get('/api/v2/metrics/', {'node': 'awx-2'}))
assert get_dispatcherd_metrics(request) == ''
def test_dispatcherd_metrics_metric_filter_excludes_unrelated(mocker):
def fake_urlopen(*args, **kwargs):
raise AssertionError("urlopen should not be called when metric filter excludes dispatcherd metrics")
mocker.patch('urllib.request.urlopen', fake_urlopen)
request = Request(RequestFactory().get('/api/v2/metrics/', {'metric': 'awx_system_info'}))
assert get_dispatcherd_metrics(request) == ''

View File

@@ -463,6 +463,26 @@ class TestInventorySourceCredential:
assert 'Cloud-based inventory sources (such as ec2)' in r.data['credential'][0] assert 'Cloud-based inventory sources (such as ec2)' in r.data['credential'][0]
assert 'require credentials for the matching cloud service' in r.data['credential'][0] assert 'require credentials for the matching cloud service' in r.data['credential'][0]
def test_credential_dict_value_returns_400(self, inventory, admin_user, put):
"""Passing a dict for the credential field should return 400, not 500.
Reproduces a bug where int() raises TypeError on non-scalar types
(dict, list) which was uncaught, resulting in a 500 Internal Server Error.
"""
inv_src = InventorySource.objects.create(name='test-src', inventory=inventory, source='ec2')
r = put(
url=reverse('api:inventory_source_detail', kwargs={'pk': inv_src.pk}),
data={
'name': 'test-src',
'inventory': inventory.pk,
'source': 'ec2',
'credential': {'username': 'admin', 'password': 'secret'},
},
user=admin_user,
expect=400,
)
assert r.status_code == 400
def test_vault_credential_not_allowed(self, project, inventory, vault_credential, admin_user, post): def test_vault_credential_not_allowed(self, project, inventory, vault_credential, admin_user, post):
"""Vault credentials cannot be associated via the deprecated field""" """Vault credentials cannot be associated via the deprecated field"""
# TODO: when feature is added, add tests to use the related credentials # TODO: when feature is added, add tests to use the related credentials

View File

@@ -485,3 +485,47 @@ class TestJobTemplateCallbackProxyIntegration:
expect=400, expect=400,
**headers **headers
) )
@override_settings(REMOTE_HOST_HEADERS=['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST'], PROXY_IP_ALLOWED_LIST=[])
def test_only_first_entry_in_comma_separated_header_is_considered(self, job_template, admin_user, post):
"""
Test that only the first entry in a comma-separated header value is used for host matching.
This is important for X-Forwarded-For style headers where the format is "client, proxy1, proxy2".
Only the original client (first entry) should be matched against inventory hosts.
"""
# Create host that matches the SECOND entry in the comma-separated list
job_template.inventory.hosts.create(name='second-host.example.com')
headers = {
# First entry is 'first-host.example.com', second is 'second-host.example.com'
# Only the first should be considered, so this should NOT match
'HTTP_X_FROM_THE_LOAD_BALANCER': 'first-host.example.com, second-host.example.com',
'REMOTE_ADDR': 'unrelated-addr',
'REMOTE_HOST': 'unrelated-host',
}
# Should return 400 because only 'first-host.example.com' is considered,
# and that host is NOT in the inventory
r = post(
url=reverse('api:job_template_callback', kwargs={'pk': job_template.pk}), data={'host_config_key': 'abcd'}, user=admin_user, expect=400, **headers
)
assert r.data['msg'] == 'No matching host could be found!'
@override_settings(REMOTE_HOST_HEADERS=['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST'], PROXY_IP_ALLOWED_LIST=[])
def test_first_entry_in_comma_separated_header_matches(self, job_template, admin_user, post):
"""
Test that the first entry in a comma-separated header value correctly matches an inventory host.
"""
# Create host that matches the FIRST entry in the comma-separated list
job_template.inventory.hosts.create(name='first-host.example.com')
headers = {
# First entry is 'first-host.example.com', second is 'second-host.example.com'
# The first entry matches the inventory host
'HTTP_X_FROM_THE_LOAD_BALANCER': 'first-host.example.com, second-host.example.com',
'REMOTE_ADDR': 'unrelated-addr',
'REMOTE_HOST': 'unrelated-host',
}
# Should return 201 because 'first-host.example.com' is the first entry and matches
post(url=reverse('api:job_template_callback', kwargs={'pk': job_template.pk}), data={'host_config_key': 'abcd'}, user=admin_user, expect=201, **headers)

View File

@@ -21,7 +21,7 @@ def test_feature_flags_list_endpoint_override(get, flag_val):
bob = User.objects.create(username='bob', password='test_user', is_superuser=True) bob = User.objects.create(username='bob', password='test_user', is_superuser=True)
AAPFlag.objects.all().delete() AAPFlag.objects.all().delete()
flag_name = "FEATURE_DISPATCHERD_ENABLED" flag_name = "FEATURE_INDIRECT_NODE_COUNTING_ENABLED"
setattr(settings, flag_name, flag_val) setattr(settings, flag_name, flag_val)
seed_feature_flags() seed_feature_flags()
url = "/api/v2/feature_flags/states/" url = "/api/v2/feature_flags/states/"

View File

@@ -0,0 +1,17 @@
import pytest
from awx.main.dispatch.config import get_dispatcherd_config
from awx.main.management.commands.dispatcherd import _hash_config
@pytest.mark.django_db
def test_dispatcherd_config_hash_is_stable(settings, monkeypatch):
monkeypatch.setenv('AWX_COMPONENT', 'dispatcher')
settings.CLUSTER_HOST_ID = 'test-node'
settings.JOB_EVENT_WORKERS = 1
settings.DISPATCHER_SCHEDULE = {}
config_one = get_dispatcherd_config(for_service=True)
config_two = get_dispatcherd_config(for_service=True)
assert _hash_config(config_one) == _hash_config(config_two)

View File

@@ -3,7 +3,7 @@ import pytest
# AWX # AWX
from awx.main.ha import is_ha_environment from awx.main.ha import is_ha_environment
from awx.main.models.ha import Instance from awx.main.models.ha import Instance
from awx.main.dispatch.pool import get_auto_max_workers from awx.main.utils.common import get_auto_max_workers
# Django # Django
from django.test.utils import override_settings from django.test.utils import override_settings

View File

@@ -1,5 +1,6 @@
import itertools import itertools
import pytest import pytest
from uuid import uuid4
# CRUM # CRUM
from crum import impersonate from crum import impersonate
@@ -33,6 +34,64 @@ def test_soft_unique_together(post, project, admin_user):
assert 'combination already exists' in str(r.data) assert 'combination already exists' in str(r.data)
@pytest.mark.django_db
class TestJobCancel:
"""
Coverage for UnifiedJob.cancel, focused on interaction with dispatcherd objects.
Using mocks for the dispatcherd objects, because tests by default use a no-op broker.
"""
def test_cancel_sets_flag_and_clears_start_args(self, mocker):
job = Job.objects.create(status='running', name='foo-job', celery_task_id=str(uuid4()), controller_node='foo', start_args='{"secret": "value"}')
job.websocket_emit_status = mocker.MagicMock()
assert job.can_cancel is True
assert job.cancel_flag is False
job.cancel()
job.refresh_from_db()
assert job.cancel_flag is True
assert job.start_args == ''
def test_cancel_sets_job_explanation(self, mocker):
job = Job.objects.create(status='running', name='foo-job', celery_task_id=str(uuid4()), controller_node='foo')
job.websocket_emit_status = mocker.MagicMock()
job_explanation = 'giggity giggity'
job.cancel(job_explanation=job_explanation)
job.refresh_from_db()
assert job.job_explanation == job_explanation
def test_cancel_sends_control_message(self, mocker):
celery_task_id = str(uuid4())
job = Job.objects.create(status='running', name='foo-job', celery_task_id=celery_task_id, controller_node='foo')
job.websocket_emit_status = mocker.MagicMock()
control = mocker.MagicMock()
get_control = mocker.patch('awx.main.models.unified_jobs.get_control_from_settings', return_value=control)
job.cancel()
get_control.assert_called_once_with(default_publish_channel='foo')
control.control.assert_called_once_with('cancel', data={'uuid': celery_task_id})
def test_cancel_refreshes_task_id_before_sending_control(self, mocker):
job = Job.objects.create(status='pending', name='foo-job', celery_task_id='', controller_node='bar')
job.websocket_emit_status = mocker.MagicMock()
celery_task_id = str(uuid4())
Job.objects.filter(pk=job.pk).update(status='running', celery_task_id=celery_task_id)
control = mocker.MagicMock()
get_control = mocker.patch('awx.main.models.unified_jobs.get_control_from_settings', return_value=control)
refresh_spy = mocker.spy(job, 'refresh_from_db')
job.cancel()
refresh_spy.assert_called_once_with(fields=['celery_task_id', 'controller_node'])
get_control.assert_called_once_with(default_publish_channel='bar')
control.control.assert_called_once_with('cancel', data={'uuid': celery_task_id})
@pytest.mark.django_db @pytest.mark.django_db
class TestCreateUnifiedJob: class TestCreateUnifiedJob:
""" """

View File

@@ -9,7 +9,7 @@ from unittest import mock
import pytest import pytest
from awx.main.tasks.system import CleanupImagesAndFiles, execution_node_health_check, inspect_established_receptor_connections, clear_setting_cache from awx.main.tasks.system import CleanupImagesAndFiles, execution_node_health_check, inspect_established_receptor_connections, clear_setting_cache
from awx.main.management.commands.run_dispatcher import Command from awx.main.management.commands.dispatcherd import Command
from awx.main.models import Instance, Job, ReceptorAddress, InstanceLink from awx.main.models import Instance, Job, ReceptorAddress, InstanceLink

View File

@@ -1,20 +1,12 @@
import datetime import datetime
import multiprocessing
import random
import signal
import time
import yaml
from unittest import mock from unittest import mock
from flags.state import disable_flag, enable_flag
from django.utils.timezone import now as tz_now from django.utils.timezone import now as tz_now
import pytest import pytest
from awx.main.models import Job, WorkflowJob, Instance from awx.main.models import Job, WorkflowJob, Instance
from awx.main.dispatch import reaper from awx.main.dispatch import reaper
from awx.main.dispatch.pool import StatefulPoolWorker, WorkerPool, AutoscalePool from awx.main.tasks import system
from awx.main.dispatch.publish import task from dispatcherd.publish import task
from awx.main.dispatch.worker import BaseWorker, TaskWorker
from awx.main.dispatch.periodic import Scheduler
''' '''
Prevent logger.<warn, debug, error> calls from triggering database operations Prevent logger.<warn, debug, error> calls from triggering database operations
@@ -57,294 +49,6 @@ def multiply(a, b):
return a * b return a * b
class SimpleWorker(BaseWorker):
def perform_work(self, body, *args):
pass
class ResultWriter(BaseWorker):
def perform_work(self, body, result_queue):
result_queue.put(body + '!!!')
class SlowResultWriter(BaseWorker):
def perform_work(self, body, result_queue):
time.sleep(3)
super(SlowResultWriter, self).perform_work(body, result_queue)
@pytest.mark.usefixtures("disable_database_settings")
class TestPoolWorker:
def setup_method(self, test_method):
self.worker = StatefulPoolWorker(1000, self.tick, tuple())
def tick(self):
self.worker.finished.put(self.worker.queue.get()['uuid'])
time.sleep(0.5)
def test_qsize(self):
assert self.worker.qsize == 0
for i in range(3):
self.worker.put({'task': 'abc123'})
assert self.worker.qsize == 3
def test_put(self):
assert len(self.worker.managed_tasks) == 0
assert self.worker.messages_finished == 0
self.worker.put({'task': 'abc123'})
assert len(self.worker.managed_tasks) == 1
assert self.worker.messages_sent == 1
def test_managed_tasks(self):
self.worker.put({'task': 'abc123'})
self.worker.calculate_managed_tasks()
assert len(self.worker.managed_tasks) == 1
self.tick()
self.worker.calculate_managed_tasks()
assert len(self.worker.managed_tasks) == 0
def test_current_task(self):
self.worker.put({'task': 'abc123'})
assert self.worker.current_task['task'] == 'abc123'
def test_quit(self):
self.worker.quit()
assert self.worker.queue.get() == 'QUIT'
def test_idle_busy(self):
assert self.worker.idle is True
assert self.worker.busy is False
self.worker.put({'task': 'abc123'})
assert self.worker.busy is True
assert self.worker.idle is False
@pytest.mark.django_db
class TestWorkerPool:
def setup_method(self, test_method):
self.pool = WorkerPool(min_workers=3)
def teardown_method(self, test_method):
self.pool.stop(signal.SIGTERM)
def test_worker(self):
self.pool.init_workers(SimpleWorker().work_loop)
assert len(self.pool) == 3
for worker in self.pool.workers:
assert worker.messages_sent == 0
assert worker.alive is True
def test_single_task(self):
self.pool.init_workers(SimpleWorker().work_loop)
self.pool.write(0, 'xyz')
assert self.pool.workers[0].messages_sent == 1 # worker at index 0 handled one task
assert self.pool.workers[1].messages_sent == 0
assert self.pool.workers[2].messages_sent == 0
def test_queue_preference(self):
self.pool.init_workers(SimpleWorker().work_loop)
self.pool.write(2, 'xyz')
assert self.pool.workers[0].messages_sent == 0
assert self.pool.workers[1].messages_sent == 0
assert self.pool.workers[2].messages_sent == 1 # worker at index 2 handled one task
def test_worker_processing(self):
result_queue = multiprocessing.Queue()
self.pool.init_workers(ResultWriter().work_loop, result_queue)
for i in range(10):
self.pool.write(random.choice(range(len(self.pool))), 'Hello, Worker {}'.format(i))
all_messages = [result_queue.get(timeout=1) for i in range(10)]
all_messages.sort()
assert all_messages == ['Hello, Worker {}!!!'.format(i) for i in range(10)]
total_handled = sum([worker.messages_sent for worker in self.pool.workers])
assert total_handled == 10
@pytest.mark.django_db
class TestAutoScaling:
def setup_method(self, test_method):
self.pool = AutoscalePool(min_workers=2, max_workers=10)
def teardown_method(self, test_method):
self.pool.stop(signal.SIGTERM)
def test_scale_up(self):
result_queue = multiprocessing.Queue()
self.pool.init_workers(SlowResultWriter().work_loop, result_queue)
# start with two workers, write an event to each worker and make it busy
assert len(self.pool) == 2
for i, w in enumerate(self.pool.workers):
w.put('Hello, Worker {}'.format(0))
assert len(self.pool) == 2
# wait for the subprocesses to start working on their tasks and be marked busy
time.sleep(1)
assert self.pool.should_grow
# write a third message, expect a new worker to spawn because all
# workers are busy
self.pool.write(0, 'Hello, Worker {}'.format(2))
assert len(self.pool) == 3
def test_scale_down(self):
self.pool.init_workers(ResultWriter().work_loop, multiprocessing.Queue())
# start with two workers, and scale up to 10 workers
assert len(self.pool) == 2
for i in range(8):
self.pool.up()
assert len(self.pool) == 10
# cleanup should scale down to 8 workers
self.pool.cleanup()
assert len(self.pool) == 2
def test_max_scale_up(self):
self.pool.init_workers(ResultWriter().work_loop, multiprocessing.Queue())
assert len(self.pool) == 2
for i in range(25):
self.pool.up()
assert self.pool.max_workers == 10
assert self.pool.full is True
assert len(self.pool) == 10
def test_equal_worker_distribution(self):
# if all workers are busy, spawn new workers *before* adding messages
# to an existing queue
self.pool.init_workers(SlowResultWriter().work_loop, multiprocessing.Queue)
# start with two workers, write an event to each worker and make it busy
assert len(self.pool) == 2
for i in range(10):
self.pool.write(0, 'Hello, World!')
assert len(self.pool) == 10
for w in self.pool.workers:
assert w.busy
assert len(w.managed_tasks) == 1
# the queue is full at 10, the _next_ write should put the message into
# a worker's backlog
assert len(self.pool) == 10
for w in self.pool.workers:
assert w.messages_sent == 1
self.pool.write(0, 'Hello, World!')
assert len(self.pool) == 10
assert self.pool.workers[0].messages_sent == 2
@pytest.mark.timeout(20)
def test_lost_worker_autoscale(self):
# if a worker exits, it should be replaced automatically up to min_workers
self.pool.init_workers(ResultWriter().work_loop, multiprocessing.Queue())
# start with two workers, kill one of them
assert len(self.pool) == 2
assert not self.pool.should_grow
alive_pid = self.pool.workers[1].pid
self.pool.workers[0].process.kill()
self.pool.workers[0].process.join() # waits for process to full terminate
# clean up and the dead worker
self.pool.cleanup()
assert len(self.pool) == 1
assert self.pool.workers[0].pid == alive_pid
# the next queue write should replace the lost worker
self.pool.write(0, 'Hello, Worker')
assert len(self.pool) == 2
@pytest.mark.usefixtures("disable_database_settings")
class TestTaskDispatcher:
@property
def tm(self):
return TaskWorker()
def test_function_dispatch(self):
result = self.tm.perform_work({'task': 'awx.main.tests.functional.test_dispatch.add', 'args': [2, 2]})
assert result == 4
def test_function_dispatch_must_be_decorated(self):
result = self.tm.perform_work({'task': 'awx.main.tests.functional.test_dispatch.restricted', 'args': [2, 2]})
assert isinstance(result, ValueError)
assert str(result) == 'awx.main.tests.functional.test_dispatch.restricted is not decorated with @task()' # noqa
def test_method_dispatch(self):
result = self.tm.perform_work({'task': 'awx.main.tests.functional.test_dispatch.Adder', 'args': [2, 2]})
assert result == 4
def test_method_dispatch_must_be_decorated(self):
result = self.tm.perform_work({'task': 'awx.main.tests.functional.test_dispatch.Restricted', 'args': [2, 2]})
assert isinstance(result, ValueError)
assert str(result) == 'awx.main.tests.functional.test_dispatch.Restricted is not decorated with @task()' # noqa
def test_python_function_cannot_be_imported(self):
result = self.tm.perform_work(
{
'task': 'os.system',
'args': ['ls'],
}
)
assert isinstance(result, ValueError)
assert str(result) == 'os.system is not a valid awx task' # noqa
def test_undefined_function_cannot_be_imported(self):
result = self.tm.perform_work({'task': 'awx.foo.bar'})
assert isinstance(result, ModuleNotFoundError)
assert str(result) == "No module named 'awx.foo'" # noqa
@pytest.mark.django_db
class TestTaskPublisher:
@pytest.fixture(autouse=True)
def _disable_dispatcherd(self):
flag_name = "FEATURE_DISPATCHERD_ENABLED"
disable_flag(flag_name)
yield
enable_flag(flag_name)
def test_function_callable(self):
assert add(2, 2) == 4
def test_method_callable(self):
assert Adder().run(2, 2) == 4
def test_function_apply_async(self):
message, queue = add.apply_async([2, 2], queue='foobar')
assert message['args'] == [2, 2]
assert message['kwargs'] == {}
assert message['task'] == 'awx.main.tests.functional.test_dispatch.add'
assert queue == 'foobar'
def test_method_apply_async(self):
message, queue = Adder.apply_async([2, 2], queue='foobar')
assert message['args'] == [2, 2]
assert message['kwargs'] == {}
assert message['task'] == 'awx.main.tests.functional.test_dispatch.Adder'
assert queue == 'foobar'
def test_apply_async_queue_required(self):
with pytest.raises(ValueError) as e:
message, queue = add.apply_async([2, 2])
assert "awx.main.tests.functional.test_dispatch.add: Queue value required and may not be None" == e.value.args[0]
def test_queue_defined_in_task_decorator(self):
message, queue = multiply.apply_async([2, 2])
assert queue == 'hard-math'
def test_queue_overridden_from_task_decorator(self):
message, queue = multiply.apply_async([2, 2], queue='not-so-hard')
assert queue == 'not-so-hard'
def test_apply_with_callable_queuename(self):
message, queue = add.apply_async([2, 2], queue=lambda: 'called')
assert queue == 'called'
yesterday = tz_now() - datetime.timedelta(days=1) yesterday = tz_now() - datetime.timedelta(days=1)
minute = tz_now() - datetime.timedelta(seconds=120) minute = tz_now() - datetime.timedelta(seconds=120)
now = tz_now() now = tz_now()
@@ -358,11 +62,6 @@ class TestJobReaper(object):
('running', '', '', None, False), # running, not assigned to the instance ('running', '', '', None, False), # running, not assigned to the instance
('running', 'awx', '', None, True), # running, has the instance as its execution_node ('running', 'awx', '', None, True), # running, has the instance as its execution_node
('running', '', 'awx', None, True), # running, has the instance as its controller_node ('running', '', 'awx', None, True), # running, has the instance as its controller_node
('waiting', '', '', None, False), # waiting, not assigned to the instance
('waiting', 'awx', '', None, False), # waiting, was edited less than a minute ago
('waiting', '', 'awx', None, False), # waiting, was edited less than a minute ago
('waiting', 'awx', '', yesterday, False), # waiting, managed by another node, ignore
('waiting', '', 'awx', yesterday, True), # waiting, assigned to the controller_node, stale
], ],
) )
def test_should_reap(self, status, fail, execution_node, controller_node, modified): def test_should_reap(self, status, fail, execution_node, controller_node, modified):
@@ -380,7 +79,6 @@ class TestJobReaper(object):
# (because .save() overwrites it to _now_) # (because .save() overwrites it to _now_)
Job.objects.filter(id=j.id).update(modified=modified) Job.objects.filter(id=j.id).update(modified=modified)
reaper.reap(i) reaper.reap(i)
reaper.reap_waiting(i)
job = Job.objects.first() job = Job.objects.first()
if fail: if fail:
assert job.status == 'failed' assert job.status == 'failed'
@@ -389,6 +87,20 @@ class TestJobReaper(object):
else: else:
assert job.status == status assert job.status == status
def test_waiting_job_sent_back_to_pending(self):
this_inst = Instance(hostname='awx')
this_inst.save()
lost_inst = Instance(hostname='lost', node_type=Instance.Types.EXECUTION, node_state=Instance.States.UNAVAILABLE)
lost_inst.save()
job = Job.objects.create(status='waiting', controller_node=lost_inst.hostname, execution_node='lost')
system._heartbeat_handle_lost_instances([lost_inst], this_inst)
job.refresh_from_db()
assert job.status == 'pending'
assert job.controller_node == ''
assert job.execution_node == ''
@pytest.mark.parametrize( @pytest.mark.parametrize(
'excluded_uuids, fail, started', 'excluded_uuids, fail, started',
[ [
@@ -448,76 +160,3 @@ class TestJobReaper(object):
assert job.started > ref_time assert job.started > ref_time
assert job.status == 'running' assert job.status == 'running'
assert job.job_explanation == '' assert job.job_explanation == ''
@pytest.mark.django_db
class TestScheduler:
def test_too_many_schedules_freak_out(self):
with pytest.raises(RuntimeError):
Scheduler({'job1': {'schedule': datetime.timedelta(seconds=1)}, 'job2': {'schedule': datetime.timedelta(seconds=1)}})
def test_spread_out(self):
scheduler = Scheduler(
{
'job1': {'schedule': datetime.timedelta(seconds=16)},
'job2': {'schedule': datetime.timedelta(seconds=16)},
'job3': {'schedule': datetime.timedelta(seconds=16)},
'job4': {'schedule': datetime.timedelta(seconds=16)},
}
)
assert [job.offset for job in scheduler.jobs] == [0, 4, 8, 12]
def test_missed_schedule(self, mocker):
scheduler = Scheduler({'job1': {'schedule': datetime.timedelta(seconds=10)}})
assert scheduler.jobs[0].missed_runs(time.time() - scheduler.global_start) == 0
mocker.patch('awx.main.dispatch.periodic.time.time', return_value=scheduler.global_start + 50)
scheduler.get_and_mark_pending()
assert scheduler.jobs[0].missed_runs(50) > 1
def test_advance_schedule(self, mocker):
scheduler = Scheduler(
{
'job1': {'schedule': datetime.timedelta(seconds=30)},
'joba': {'schedule': datetime.timedelta(seconds=20)},
'jobb': {'schedule': datetime.timedelta(seconds=20)},
}
)
for job in scheduler.jobs:
# HACK: the offsets automatically added make this a hard test to write... so remove offsets
job.offset = 0.0
mocker.patch('awx.main.dispatch.periodic.time.time', return_value=scheduler.global_start + 29)
to_run = scheduler.get_and_mark_pending()
assert set(job.name for job in to_run) == set(['joba', 'jobb'])
mocker.patch('awx.main.dispatch.periodic.time.time', return_value=scheduler.global_start + 39)
to_run = scheduler.get_and_mark_pending()
assert len(to_run) == 1
assert to_run[0].name == 'job1'
@staticmethod
def get_job(scheduler, name):
for job in scheduler.jobs:
if job.name == name:
return job
def test_scheduler_debug(self, mocker):
scheduler = Scheduler(
{
'joba': {'schedule': datetime.timedelta(seconds=20)},
'jobb': {'schedule': datetime.timedelta(seconds=50)},
'jobc': {'schedule': datetime.timedelta(seconds=500)},
'jobd': {'schedule': datetime.timedelta(seconds=20)},
}
)
rel_time = 119.9 # slightly under the 6th 20-second bin, to avoid offset problems
current_time = scheduler.global_start + rel_time
mocker.patch('awx.main.dispatch.periodic.time.time', return_value=current_time - 1.0e-8)
self.get_job(scheduler, 'jobb').mark_run(rel_time)
self.get_job(scheduler, 'jobd').mark_run(rel_time - 20.0)
output = scheduler.debug()
data = yaml.safe_load(output)
assert data['schedule_list']['jobc']['last_run_seconds_ago'] is None
assert data['schedule_list']['joba']['missed_runs'] == 4
assert data['schedule_list']['jobd']['missed_runs'] == 3
assert data['schedule_list']['jobd']['completed_runs'] == 1
assert data['schedule_list']['jobb']['next_run_in_seconds'] > 25.0

View File

@@ -50,7 +50,7 @@ def test_job_capacity_and_with_inactive_node():
i.save() i.save()
with override_settings(CLUSTER_HOST_ID=i.hostname): with override_settings(CLUSTER_HOST_ID=i.hostname):
with mock.patch.object(redis.client.Redis, 'ping', lambda self: True): with mock.patch.object(redis.client.Redis, 'ping', lambda self: True):
cluster_node_heartbeat() cluster_node_heartbeat(None)
i = Instance.objects.get(id=i.id) i = Instance.objects.get(id=i.id)
assert i.capacity == 0 assert i.capacity == 0

View File

@@ -173,3 +173,54 @@ class TestMigrationSmoke:
assert Role.objects.filter( assert Role.objects.filter(
singleton_name='system_administrator', role_field='system_administrator' singleton_name='system_administrator', role_field='system_administrator'
).exists(), "expected to find a system_administrator singleton role" ).exists(), "expected to find a system_administrator singleton role"
@pytest.mark.django_db
class TestGithubAppBug:
"""
Tests that `awx-manage createsuperuser` runs successfully after
the `github_app` CredentialType kind is updated to `github_app_lookup`
via the migration.
"""
def test_after_github_app_kind_migration(self, migrator):
"""
Verifies that `createsuperuser` does not raise a KeyError
after the 0204_squashed_deletions migration (which includes
the `update_github_app_kind` logic) is applied.
"""
# 1. Apply migrations up to the point *before* the 0204_squashed_deletions migration.
# This simulates the state where the problematic CredentialType might exist.
# We use 0203_remove_team_of_teams as the direct predecessor.
old_state = migrator.apply_tested_migration(('main', '0203_remove_team_of_teams'))
# Get the CredentialType model from the historical state.
CredentialType = old_state.apps.get_model('main', 'CredentialType')
# Create a CredentialType with the old, problematic 'namespace' value
CredentialType.objects.create(
name='Legacy GitHub App Credential',
kind='external',
namespace='github_app', # The namespace that causes the KeyError in the registry lookup
managed=True,
created=now(),
modified=now(),
)
# Apply the migration that includes the fix (0204_squashed_deletions).
new_state = migrator.apply_tested_migration(('main', '0204_squashed_deletions'))
# Verify that the CredentialType with the old 'kind' no longer exists
# and the 'kind' has been updated to the new value.
CredentialType = new_state.apps.get_model('main', 'CredentialType') # Get CredentialType model from the new state
# Assertion 1: The CredentialType with the old 'github_app' kind should no longer exist.
assert not CredentialType.objects.filter(
namespace='github_app'
).exists(), "CredentialType with old 'github_app' kind should no longer exist after migration."
# Assertion 2: The CredentialType should now exist with the new 'github_app_lookup' kind
# and retain its original name.
assert CredentialType.objects.filter(
namespace='github_app_lookup', name='Legacy GitHub App Credential'
).exists(), "CredentialType should be updated to 'github_app_lookup' and retain its name."

View File

@@ -69,7 +69,7 @@ def live_tmp_folder():
settings._awx_conf_memoizedcache.clear() settings._awx_conf_memoizedcache.clear()
# cache is cleared in test environment, but need to clear in test environment # cache is cleared in test environment, but need to clear in test environment
clear_setting_cache.delay(['AWX_ISOLATION_SHOW_PATHS']) clear_setting_cache.delay(['AWX_ISOLATION_SHOW_PATHS'])
time.sleep(0.2) # allow task to finish, we have no real metric to know time.sleep(5.0) # for _awx_conf_memoizedcache to expire on all workers
else: else:
logger.info(f'Believed that {path} is already in settings.AWX_ISOLATION_SHOW_PATHS: {settings.AWX_ISOLATION_SHOW_PATHS}') logger.info(f'Believed that {path} is already in settings.AWX_ISOLATION_SHOW_PATHS: {settings.AWX_ISOLATION_SHOW_PATHS}')
return path return path

View File

@@ -7,9 +7,6 @@ from awx.settings.development import * # NOQA
# Some things make decisions based on settings.SETTINGS_MODULE, so this is done for that # Some things make decisions based on settings.SETTINGS_MODULE, so this is done for that
SETTINGS_MODULE = 'awx.settings.development' SETTINGS_MODULE = 'awx.settings.development'
# Turn off task submission, because sqlite3 does not have pg_notify
DISPATCHER_MOCK_PUBLISH = True
# Use SQLite for unit tests instead of PostgreSQL. If the lines below are # Use SQLite for unit tests instead of PostgreSQL. If the lines below are
# commented out, Django will create the test_awx-dev database in PostgreSQL to # commented out, Django will create the test_awx-dev database in PostgreSQL to
# run unit tests. # run unit tests.

View File

@@ -0,0 +1,49 @@
import pytest
from collections import OrderedDict
from unittest import mock
from rest_framework.exceptions import ValidationError
from awx.api.fields import DeprecatedCredentialField
class TestDeprecatedCredentialField:
"""Test that DeprecatedCredentialField handles unexpected input types gracefully."""
def test_dict_value_raises_validation_error(self):
"""Passing a dict instead of an integer should return a 400 validation error, not a 500 TypeError."""
field = DeprecatedCredentialField()
with pytest.raises(ValidationError):
field.to_internal_value({"username": "admin", "password": "secret"})
def test_ordered_dict_value_raises_validation_error(self):
"""Passing an OrderedDict should return a 400 validation error, not a 500 TypeError."""
field = DeprecatedCredentialField()
with pytest.raises(ValidationError):
field.to_internal_value(OrderedDict([("username", "admin")]))
def test_list_value_raises_validation_error(self):
"""Passing a list should return a 400 validation error, not a 500 TypeError."""
field = DeprecatedCredentialField()
with pytest.raises(ValidationError):
field.to_internal_value([1, 2, 3])
def test_string_value_raises_validation_error(self):
"""Passing a non-numeric string should return a 400 validation error."""
field = DeprecatedCredentialField()
with pytest.raises(ValidationError):
field.to_internal_value("not_a_number")
@mock.patch('awx.api.fields.Credential.objects')
def test_valid_integer_value_works(self, mock_cred_objects):
"""Passing a valid integer PK should work when the credential exists."""
mock_cred_objects.get.return_value = mock.MagicMock()
field = DeprecatedCredentialField()
assert field.to_internal_value(42) == 42
@mock.patch('awx.api.fields.Credential.objects')
def test_valid_string_integer_value_works(self, mock_cred_objects):
"""Passing a numeric string PK should work when the credential exists."""
mock_cred_objects.get.return_value = mock.MagicMock()
field = DeprecatedCredentialField()
assert field.to_internal_value("42") == 42

View File

@@ -1,3 +1,4 @@
import copy
import warnings import warnings
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
@@ -8,6 +9,7 @@ from awx.api.schema import (
AuthenticatedSpectacularAPIView, AuthenticatedSpectacularAPIView,
AuthenticatedSpectacularSwaggerView, AuthenticatedSpectacularSwaggerView,
AuthenticatedSpectacularRedocView, AuthenticatedSpectacularRedocView,
filter_credential_type_schema,
) )
@@ -271,3 +273,152 @@ class TestAuthenticatedSchemaViews:
def test_authenticated_spectacular_redoc_view_requires_authentication(self): def test_authenticated_spectacular_redoc_view_requires_authentication(self):
"""Test that AuthenticatedSpectacularRedocView requires authentication.""" """Test that AuthenticatedSpectacularRedocView requires authentication."""
assert IsAuthenticated in AuthenticatedSpectacularRedocView.permission_classes assert IsAuthenticated in AuthenticatedSpectacularRedocView.permission_classes
class TestFilterCredentialTypeSchema:
"""Unit tests for filter_credential_type_schema postprocessing hook."""
def test_filters_both_schemas_correctly(self):
"""Test that both CredentialTypeRequest and PatchedCredentialTypeRequest schemas are filtered."""
result = {
'components': {
'schemas': {
'CredentialTypeRequest': {
'properties': {
'kind': {
'enum': [
'ssh',
'vault',
'net',
'scm',
'cloud',
'registry',
'token',
'insights',
'external',
'kubernetes',
'galaxy',
'cryptography',
None,
],
'type': 'string',
}
}
},
'PatchedCredentialTypeRequest': {
'properties': {
'kind': {
'enum': [
'ssh',
'vault',
'net',
'scm',
'cloud',
'registry',
'token',
'insights',
'external',
'kubernetes',
'galaxy',
'cryptography',
None,
],
'type': 'string',
}
}
},
}
}
}
returned = filter_credential_type_schema(result, None, None, None)
# POST/PUT schema: no None (required field)
assert result['components']['schemas']['CredentialTypeRequest']['properties']['kind']['enum'] == ['cloud', 'net']
assert result['components']['schemas']['CredentialTypeRequest']['properties']['kind']['description'] == "* `cloud` - Cloud\\n* `net` - Network"
# PATCH schema: includes None (optional field)
assert result['components']['schemas']['PatchedCredentialTypeRequest']['properties']['kind']['enum'] == ['cloud', 'net', None]
assert result['components']['schemas']['PatchedCredentialTypeRequest']['properties']['kind']['description'] == "* `cloud` - Cloud\\n* `net` - Network"
# Other properties should be preserved
assert result['components']['schemas']['CredentialTypeRequest']['properties']['kind']['type'] == 'string'
# Function should return the result
assert returned is result
def test_handles_empty_result(self):
"""Test graceful handling when result dict is empty."""
result = {}
original = copy.deepcopy(result)
returned = filter_credential_type_schema(result, None, None, None)
assert result == original
assert returned is result
def test_handles_missing_enum(self):
"""Test that schemas without enum key are not modified."""
result = {'components': {'schemas': {'CredentialTypeRequest': {'properties': {'kind': {'type': 'string', 'description': 'Some description'}}}}}}
original = copy.deepcopy(result)
filter_credential_type_schema(result, None, None, None)
assert result == original
def test_filters_only_target_schemas(self):
"""Test that only CredentialTypeRequest schemas are modified, not others."""
result = {
'components': {
'schemas': {
'CredentialTypeRequest': {'properties': {'kind': {'enum': ['ssh', 'cloud', 'net', None]}}},
'OtherSchema': {'properties': {'kind': {'enum': ['option1', 'option2']}}},
}
}
}
other_schema_before = copy.deepcopy(result['components']['schemas']['OtherSchema'])
filter_credential_type_schema(result, None, None, None)
# CredentialTypeRequest should be filtered (no None for required field)
assert result['components']['schemas']['CredentialTypeRequest']['properties']['kind']['enum'] == ['cloud', 'net']
# OtherSchema should be unchanged
assert result['components']['schemas']['OtherSchema'] == other_schema_before
def test_handles_only_one_schema_present(self):
"""Test that function works when only one target schema is present."""
result = {'components': {'schemas': {'CredentialTypeRequest': {'properties': {'kind': {'enum': ['ssh', 'cloud', 'net', None]}}}}}}
filter_credential_type_schema(result, None, None, None)
assert result['components']['schemas']['CredentialTypeRequest']['properties']['kind']['enum'] == ['cloud', 'net']
def test_handles_missing_properties(self):
"""Test graceful handling when schema has no properties key."""
result = {'components': {'schemas': {'CredentialTypeRequest': {}}}}
original = copy.deepcopy(result)
filter_credential_type_schema(result, None, None, None)
assert result == original
def test_differentiates_required_vs_optional_fields(self):
"""Test that CredentialTypeRequest excludes None but PatchedCredentialTypeRequest includes it."""
result = {
'components': {
'schemas': {
'CredentialTypeRequest': {'properties': {'kind': {'enum': ['ssh', 'vault', 'net', 'scm', 'cloud', 'registry', None]}}},
'PatchedCredentialTypeRequest': {'properties': {'kind': {'enum': ['ssh', 'vault', 'net', 'scm', 'cloud', 'registry', None]}}},
}
}
}
filter_credential_type_schema(result, None, None, None)
# POST/PUT schema: no None (required field)
assert result['components']['schemas']['CredentialTypeRequest']['properties']['kind']['enum'] == ['cloud', 'net']
# PATCH schema: includes None (optional field)
assert result['components']['schemas']['PatchedCredentialTypeRequest']['properties']['kind']['enum'] == ['cloud', 'net', None]

View File

@@ -0,0 +1,92 @@
import io
import pytest
from django.core.management.base import CommandError
from awx.main.management.commands import dispatcherctl
@pytest.fixture(autouse=True)
def clear_dispatcher_env(monkeypatch, mocker):
monkeypatch.delenv('DISPATCHERD_CONFIG_FILE', raising=False)
mocker.patch.object(dispatcherctl.logging, 'basicConfig')
mocker.patch.object(dispatcherctl, 'connection', mocker.Mock(vendor='postgresql'))
def test_dispatcherctl_runs_control_with_generated_config(mocker):
command = dispatcherctl.Command()
command.stdout = io.StringIO()
data = {'foo': 'bar'}
mocker.patch.object(dispatcherctl, '_build_command_data_from_args', return_value=data)
dispatcher_setup = mocker.patch.object(dispatcherctl, 'dispatcher_setup')
config_data = {'setting': 'value'}
mocker.patch.object(dispatcherctl, 'get_dispatcherd_config', return_value=config_data)
control = mocker.Mock()
control.control_with_reply.return_value = [{'status': 'ok'}]
mocker.patch.object(dispatcherctl, 'get_control_from_settings', return_value=control)
mocker.patch.object(dispatcherctl.yaml, 'dump', return_value='payload\n')
command.handle(
command='running',
config=dispatcherctl.DEFAULT_CONFIG_FILE,
expected_replies=1,
log_level='INFO',
)
dispatcher_setup.assert_called_once_with(config_data)
control.control_with_reply.assert_called_once_with('running', data=data, expected_replies=1)
assert command.stdout.getvalue() == 'payload\n'
def test_dispatcherctl_rejects_custom_config_path():
command = dispatcherctl.Command()
command.stdout = io.StringIO()
with pytest.raises(CommandError):
command.handle(
command='running',
config='/tmp/dispatcher.yml',
expected_replies=1,
log_level='INFO',
)
def test_dispatcherctl_rejects_sqlite_db(mocker):
command = dispatcherctl.Command()
command.stdout = io.StringIO()
mocker.patch.object(dispatcherctl, 'connection', mocker.Mock(vendor='sqlite'))
with pytest.raises(CommandError, match='sqlite3'):
command.handle(
command='running',
config=dispatcherctl.DEFAULT_CONFIG_FILE,
expected_replies=1,
log_level='INFO',
)
def test_dispatcherctl_raises_when_replies_missing(mocker):
command = dispatcherctl.Command()
command.stdout = io.StringIO()
mocker.patch.object(dispatcherctl, '_build_command_data_from_args', return_value={})
mocker.patch.object(dispatcherctl, 'dispatcher_setup')
mocker.patch.object(dispatcherctl, 'get_dispatcherd_config', return_value={})
control = mocker.Mock()
control.control_with_reply.return_value = [{'status': 'ok'}]
mocker.patch.object(dispatcherctl, 'get_control_from_settings', return_value=control)
mocker.patch.object(dispatcherctl.yaml, 'dump', return_value='- status: ok\n')
with pytest.raises(CommandError):
command.handle(
command='running',
config=dispatcherctl.DEFAULT_CONFIG_FILE,
expected_replies=2,
log_level='INFO',
)
control.control_with_reply.assert_called_once_with('running', data={}, expected_replies=2)

View File

@@ -176,22 +176,22 @@ def test_display_survey_spec_encrypts_default(survey_spec_factory):
@pytest.mark.survey @pytest.mark.survey
@pytest.mark.parametrize( @pytest.mark.parametrize(
"question_type,default,min,max,expect_use,expect_value", "question_type,default,min,max,expect_valid,expect_use,expect_value",
[ [
("text", "", 0, 0, True, ''), # default used ("text", "", 0, 0, True, False, 'N/A'), # valid but empty default not sent for optional question
("text", "", 1, 0, False, 'N/A'), # value less than min length ("text", "", 1, 0, False, False, 'N/A'), # value less than min length
("password", "", 1, 0, False, 'N/A'), # passwords behave the same as text ("password", "", 1, 0, False, False, 'N/A'), # passwords behave the same as text
("multiplechoice", "", 0, 0, False, 'N/A'), # historical bug ("multiplechoice", "", 0, 0, False, False, 'N/A'), # historical bug
("multiplechoice", "zeb", 0, 0, False, 'N/A'), # zeb not in choices ("multiplechoice", "zeb", 0, 0, False, False, 'N/A'), # zeb not in choices
("multiplechoice", "coffee", 0, 0, True, 'coffee'), ("multiplechoice", "coffee", 0, 0, True, True, 'coffee'),
("multiselect", None, 0, 0, False, 'N/A'), # NOTE: Behavior is arguable, value of [] may be prefered ("multiselect", None, 0, 0, False, False, 'N/A'), # NOTE: Behavior is arguable, value of [] may be prefered
("multiselect", "", 0, 0, False, 'N/A'), ("multiselect", "", 0, 0, False, False, 'N/A'),
("multiselect", ["zeb"], 0, 0, False, 'N/A'), ("multiselect", ["zeb"], 0, 0, False, False, 'N/A'),
("multiselect", ["milk"], 0, 0, True, ["milk"]), ("multiselect", ["milk"], 0, 0, True, True, ["milk"]),
("multiselect", ["orange\nmilk"], 0, 0, False, 'N/A'), # historical bug ("multiselect", ["orange\nmilk"], 0, 0, False, False, 'N/A'), # historical bug
], ],
) )
def test_optional_survey_question_defaults(survey_spec_factory, question_type, default, min, max, expect_use, expect_value): def test_optional_survey_question_defaults(survey_spec_factory, question_type, default, min, max, expect_valid, expect_use, expect_value):
spec = survey_spec_factory( spec = survey_spec_factory(
[ [
{ {
@@ -208,7 +208,7 @@ def test_optional_survey_question_defaults(survey_spec_factory, question_type, d
jt = JobTemplate(name="test-jt", survey_spec=spec, survey_enabled=True) jt = JobTemplate(name="test-jt", survey_spec=spec, survey_enabled=True)
defaulted_extra_vars = jt._update_unified_job_kwargs({}, {}) defaulted_extra_vars = jt._update_unified_job_kwargs({}, {})
element = spec['spec'][0] element = spec['spec'][0]
if expect_use: if expect_valid:
assert jt._survey_element_validation(element, {element['variable']: element['default']}) == [] assert jt._survey_element_validation(element, {element['variable']: element['default']}) == []
else: else:
assert jt._survey_element_validation(element, {element['variable']: element['default']}) assert jt._survey_element_validation(element, {element['variable']: element['default']})
@@ -218,6 +218,28 @@ def test_optional_survey_question_defaults(survey_spec_factory, question_type, d
assert 'c' not in defaulted_extra_vars['extra_vars'] assert 'c' not in defaulted_extra_vars['extra_vars']
@pytest.mark.survey
def test_optional_survey_empty_default_with_runtime_extra_var(survey_spec_factory):
"""When a user explicitly provides an empty string at runtime for an optional
survey question, the variable should still be included in extra_vars."""
spec = survey_spec_factory(
[
{
"required": False,
"default": "",
"choices": "",
"variable": "c",
"min": 0,
"max": 0,
"type": "text",
},
]
)
jt = JobTemplate(name="test-jt", survey_spec=spec, survey_enabled=True)
defaulted_extra_vars = jt._update_unified_job_kwargs({}, {'extra_vars': json.dumps({'c': ''})})
assert json.loads(defaulted_extra_vars['extra_vars'])['c'] == ''
@pytest.mark.survey @pytest.mark.survey
@pytest.mark.parametrize( @pytest.mark.parametrize(
"question_type,default,maxlen,kwargs,expected", "question_type,default,maxlen,kwargs,expected",

View File

@@ -1,4 +1,3 @@
import pytest
from unittest import mock from unittest import mock
from awx.main.models import UnifiedJob, UnifiedJobTemplate, WorkflowJob, WorkflowJobNode, WorkflowApprovalTemplate, Job, User, Project, JobTemplate, Inventory from awx.main.models import UnifiedJob, UnifiedJobTemplate, WorkflowJob, WorkflowJobNode, WorkflowApprovalTemplate, Job, User, Project, JobTemplate, Inventory
@@ -22,52 +21,6 @@ def test_unified_job_workflow_attributes():
assert job.workflow_job_id == 1 assert job.workflow_job_id == 1
def mock_on_commit(f):
f()
@pytest.fixture
def unified_job(mocker):
mocker.patch.object(UnifiedJob, 'can_cancel', return_value=True)
j = UnifiedJob()
j.status = 'pending'
j.cancel_flag = None
j.save = mocker.MagicMock()
j.websocket_emit_status = mocker.MagicMock()
j.fallback_cancel = mocker.MagicMock()
return j
def test_cancel(unified_job):
with mock.patch('awx.main.models.unified_jobs.connection.on_commit', wraps=mock_on_commit):
unified_job.cancel()
assert unified_job.cancel_flag is True
assert unified_job.status == 'canceled'
assert unified_job.job_explanation == ''
# Note: the websocket emit status check is just reflecting the state of the current code.
# Some more thought may want to go into only emitting canceled if/when the job record
# status is changed to canceled. Unlike, currently, where it's emitted unconditionally.
unified_job.websocket_emit_status.assert_called_with("canceled")
assert [(args, kwargs) for args, kwargs in unified_job.save.call_args_list] == [
((), {'update_fields': ['cancel_flag', 'start_args']}),
((), {'update_fields': ['status']}),
]
def test_cancel_job_explanation(unified_job):
job_explanation = 'giggity giggity'
with mock.patch('awx.main.models.unified_jobs.connection.on_commit'):
unified_job.cancel(job_explanation=job_explanation)
assert unified_job.job_explanation == job_explanation
assert [(args, kwargs) for args, kwargs in unified_job.save.call_args_list] == [
((), {'update_fields': ['cancel_flag', 'start_args', 'job_explanation']}),
((), {'update_fields': ['status']}),
]
def test_organization_copy_to_jobs(): def test_organization_copy_to_jobs():
""" """
All unified job types should infer their organization from their template organization All unified job types should infer their organization from their template organization

View File

@@ -226,3 +226,140 @@ def test_send_messages_with_additional_headers():
allow_redirects=False, allow_redirects=False,
) )
assert sent_messages == 1 assert sent_messages == 1
def test_send_messages_with_redirects_ok():
with mock.patch('awx.main.notifications.webhook_backend.requests') as requests_mock, mock.patch(
'awx.main.notifications.webhook_backend.get_awx_http_client_headers'
) as version_mock:
# First two calls return redirects, third call returns 200
requests_mock.post.side_effect = [
mock.Mock(status_code=301, headers={"Location": "http://redirect1.com"}),
mock.Mock(status_code=307, headers={"Location": "http://redirect2.com"}),
mock.Mock(status_code=200),
]
version_mock.return_value = {'Content-Type': 'application/json', 'User-Agent': 'AWX 0.0.1.dev (open)'}
backend = webhook_backend.WebhookBackend('POST', None)
message = EmailMessage(
'test subject',
{'text': 'test body'},
[],
[
'http://example.com',
],
)
sent_messages = backend.send_messages(
[
message,
]
)
assert requests_mock.post.call_count == 3
requests_mock.post.assert_called_with(
url='http://redirect2.com',
auth=None,
data=json.dumps({'text': 'test body'}, ensure_ascii=False).encode('utf-8'),
headers={'Content-Type': 'application/json', 'User-Agent': 'AWX 0.0.1.dev (open)'},
verify=True,
allow_redirects=False,
)
assert sent_messages == 1
def test_send_messages_with_redirects_blank():
with mock.patch('awx.main.notifications.webhook_backend.requests') as requests_mock, mock.patch(
'awx.main.notifications.webhook_backend.get_awx_http_client_headers'
) as version_mock, mock.patch('awx.main.notifications.webhook_backend.logger') as logger_mock:
# First call returns a redirect with Location header, second call returns 301 but NO Location header
requests_mock.post.side_effect = [
mock.Mock(status_code=301, headers={"Location": "http://redirect1.com"}),
mock.Mock(status_code=301, headers={}), # 301 with no Location header
]
version_mock.return_value = {'Content-Type': 'application/json', 'User-Agent': 'AWX 0.0.1.dev (open)'}
backend = webhook_backend.WebhookBackend('POST', None, fail_silently=True)
message = EmailMessage(
'test subject',
{'text': 'test body'},
[],
[
'http://example.com',
],
)
sent_messages = backend.send_messages(
[
message,
]
)
# Should make 2 requests (initial + 1 redirect attempt)
assert requests_mock.post.call_count == 2
# The error message should be logged
logger_mock.error.assert_called_once()
error_call_args = logger_mock.error.call_args[0][0]
assert "redirect to a blank URL" in error_call_args
assert sent_messages == 0
def test_send_messages_with_redirects_max_retries_exceeded():
with mock.patch('awx.main.notifications.webhook_backend.requests') as requests_mock, mock.patch(
'awx.main.notifications.webhook_backend.get_awx_http_client_headers'
) as version_mock, mock.patch('awx.main.notifications.webhook_backend.logger') as logger_mock:
# Return MAX_RETRIES (5) redirect responses to exceed the retry limit
requests_mock.post.side_effect = [
mock.Mock(status_code=301, headers={"Location": "http://redirect1.com"}),
mock.Mock(status_code=301, headers={"Location": "http://redirect2.com"}),
mock.Mock(status_code=307, headers={"Location": "http://redirect3.com"}),
mock.Mock(status_code=301, headers={"Location": "http://redirect4.com"}),
mock.Mock(status_code=307, headers={"Location": "http://redirect5.com"}),
]
version_mock.return_value = {'Content-Type': 'application/json', 'User-Agent': 'AWX 0.0.1.dev (open)'}
backend = webhook_backend.WebhookBackend('POST', None, fail_silently=True)
message = EmailMessage(
'test subject',
{'text': 'test body'},
[],
[
'http://example.com',
],
)
sent_messages = backend.send_messages(
[
message,
]
)
# Should make exactly 5 requests (MAX_RETRIES)
assert requests_mock.post.call_count == 5
# The error message should be logged for exceeding max retries
logger_mock.error.assert_called_once()
error_call_args = logger_mock.error.call_args[0][0]
assert "max number of retries" in error_call_args
assert "[5]" in error_call_args
assert sent_messages == 0
def test_send_messages_with_error_status_code():
with mock.patch('awx.main.notifications.webhook_backend.requests') as requests_mock, mock.patch(
'awx.main.notifications.webhook_backend.get_awx_http_client_headers'
) as version_mock, mock.patch('awx.main.notifications.webhook_backend.logger') as logger_mock:
# Return a 404 error status code
requests_mock.post.return_value = mock.Mock(status_code=404)
version_mock.return_value = {'Content-Type': 'application/json', 'User-Agent': 'AWX 0.0.1.dev (open)'}
backend = webhook_backend.WebhookBackend('POST', None, fail_silently=True)
message = EmailMessage(
'test subject',
{'text': 'test body'},
[],
[
'http://example.com',
],
)
sent_messages = backend.send_messages(
[
message,
]
)
# Should make exactly 1 request
assert requests_mock.post.call_count == 1
# The error message should be logged
logger_mock.error.assert_called_once()
error_call_args = logger_mock.error.call_args[0][0]
assert "Error sending webhook notification: 404" in error_call_args
assert sent_messages == 0

View File

@@ -1,20 +1,19 @@
import pytest import pytest
from django.conf import settings from django.conf import settings
from datetime import timedelta
@pytest.mark.parametrize( @pytest.mark.parametrize(
"job_name,function_path", "task_name",
[ [
('tower_scheduler', 'awx.main.tasks.system.awx_periodic_scheduler'), 'awx.main.tasks.system.awx_periodic_scheduler',
], ],
) )
def test_CELERYBEAT_SCHEDULE(mocker, job_name, function_path): def test_DISPATCHER_SCHEDULE(mocker, task_name):
assert job_name in settings.CELERYBEAT_SCHEDULE assert task_name in settings.DISPATCHER_SCHEDULE
assert 'schedule' in settings.CELERYBEAT_SCHEDULE[job_name] assert 'schedule' in settings.DISPATCHER_SCHEDULE[task_name]
assert type(settings.CELERYBEAT_SCHEDULE[job_name]['schedule']) is timedelta assert type(settings.DISPATCHER_SCHEDULE[task_name]['schedule']) in (int, float)
assert settings.CELERYBEAT_SCHEDULE[job_name]['task'] == function_path assert settings.DISPATCHER_SCHEDULE[task_name]['task'] == task_name
# Ensures that the function exists # Ensures that the function exists
mocker.patch(function_path) mocker.patch(task_name)

View File

@@ -18,8 +18,17 @@ from awx.main.models import (
Job, Job,
Organization, Organization,
Project, Project,
JobTemplate,
UnifiedJobTemplate,
InstanceGroup,
ExecutionEnvironment,
ProjectUpdate,
InventoryUpdate,
InventorySource,
AdHocCommand,
) )
from awx.main.tasks import jobs from awx.main.tasks import jobs
from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope
@pytest.fixture @pytest.fixture
@@ -188,3 +197,233 @@ def test_invalid_host_facts(mock_facts_settings, bulk_update_sorted_by_id, priva
with pytest.raises(pytest.fail.Exception): with pytest.raises(pytest.fail.Exception):
if failures: if failures:
pytest.fail(f" {len(failures)} facts cleared failures : {','.join(failures)}") pytest.fail(f" {len(failures)} facts cleared failures : {','.join(failures)}")
@pytest.mark.parametrize(
"job_attrs,expected_claims",
[
(
{
'id': 100,
'name': 'Test Job',
'job_type': 'run',
'launch_type': 'manual',
'playbook': 'site.yml',
'organization': Organization(id=1, name='Test Org'),
'inventory': Inventory(id=2, name='Test Inventory'),
'project': Project(id=3, name='Test Project'),
'execution_environment': ExecutionEnvironment(id=4, name='Test EE'),
'job_template': JobTemplate(id=5, name='Test Job Template'),
'unified_job_template': UnifiedJobTemplate(pk=6, id=6, name='Test Unified Job Template'),
'instance_group': InstanceGroup(id=7, name='Test Instance Group'),
},
{
AutomationControllerJobScope.CLAIM_JOB_ID: 100,
AutomationControllerJobScope.CLAIM_JOB_NAME: 'Test Job',
AutomationControllerJobScope.CLAIM_JOB_TYPE: 'run',
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: 'manual',
AutomationControllerJobScope.CLAIM_PLAYBOOK_NAME: 'site.yml',
AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME: 'Test Org',
AutomationControllerJobScope.CLAIM_ORGANIZATION_ID: 1,
AutomationControllerJobScope.CLAIM_INVENTORY_NAME: 'Test Inventory',
AutomationControllerJobScope.CLAIM_INVENTORY_ID: 2,
AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_NAME: 'Test EE',
AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_ID: 4,
AutomationControllerJobScope.CLAIM_PROJECT_NAME: 'Test Project',
AutomationControllerJobScope.CLAIM_PROJECT_ID: 3,
AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_NAME: 'Test Job Template',
AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_ID: 5,
AutomationControllerJobScope.CLAIM_UNIFIED_JOB_TEMPLATE_NAME: 'Test Unified Job Template',
AutomationControllerJobScope.CLAIM_UNIFIED_JOB_TEMPLATE_ID: 6,
AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_NAME: 'Test Instance Group',
AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_ID: 7,
},
),
(
{'id': 100, 'name': 'Test', 'job_type': 'run', 'launch_type': 'manual', 'organization': Organization(id=1, name='')},
{
AutomationControllerJobScope.CLAIM_JOB_ID: 100,
AutomationControllerJobScope.CLAIM_JOB_NAME: 'Test',
AutomationControllerJobScope.CLAIM_JOB_TYPE: 'run',
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: 'manual',
AutomationControllerJobScope.CLAIM_ORGANIZATION_ID: 1,
AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME: '',
AutomationControllerJobScope.CLAIM_PLAYBOOK_NAME: '',
},
),
],
)
def test_populate_claims_for_workload(job_attrs, expected_claims):
job = Job()
for attr, value in job_attrs.items():
setattr(job, attr, value)
claims = jobs.populate_claims_for_workload(job)
assert claims == expected_claims
@pytest.mark.parametrize(
"workload_attrs,expected_claims",
[
(
{
'id': 200,
'name': 'Git Sync',
'job_type': 'check',
'launch_type': 'sync',
'organization': Organization(id=1, name='Test Org'),
'project': Project(pk=3, id=3, name='Test Project'),
'unified_job_template': Project(pk=3, id=3, name='Test Project'),
'execution_environment': ExecutionEnvironment(id=4, name='Test EE'),
'instance_group': InstanceGroup(id=7, name='Test Instance Group'),
},
{
AutomationControllerJobScope.CLAIM_JOB_ID: 200,
AutomationControllerJobScope.CLAIM_JOB_NAME: 'Git Sync',
AutomationControllerJobScope.CLAIM_JOB_TYPE: 'check',
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: 'sync',
AutomationControllerJobScope.CLAIM_LAUNCHED_BY_NAME: 'Test Project',
AutomationControllerJobScope.CLAIM_LAUNCHED_BY_ID: 3,
AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME: 'Test Org',
AutomationControllerJobScope.CLAIM_ORGANIZATION_ID: 1,
AutomationControllerJobScope.CLAIM_PROJECT_NAME: 'Test Project',
AutomationControllerJobScope.CLAIM_PROJECT_ID: 3,
AutomationControllerJobScope.CLAIM_UNIFIED_JOB_TEMPLATE_NAME: 'Test Project',
AutomationControllerJobScope.CLAIM_UNIFIED_JOB_TEMPLATE_ID: 3,
AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_NAME: 'Test EE',
AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_ID: 4,
AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_NAME: 'Test Instance Group',
AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_ID: 7,
},
),
(
{
'id': 201,
'name': 'Minimal Project Update',
'job_type': 'run',
'launch_type': 'manual',
},
{
AutomationControllerJobScope.CLAIM_JOB_ID: 201,
AutomationControllerJobScope.CLAIM_JOB_NAME: 'Minimal Project Update',
AutomationControllerJobScope.CLAIM_JOB_TYPE: 'run',
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: 'manual',
},
),
],
)
def test_populate_claims_for_project_update(workload_attrs, expected_claims):
project_update = ProjectUpdate()
for attr, value in workload_attrs.items():
setattr(project_update, attr, value)
claims = jobs.populate_claims_for_workload(project_update)
assert claims == expected_claims
@pytest.mark.parametrize(
"workload_attrs,expected_claims",
[
(
{
'id': 300,
'name': 'AWS Sync',
'launch_type': 'scheduled',
'organization': Organization(id=1, name='Test Org'),
'inventory': Inventory(id=2, name='AWS Inventory'),
'unified_job_template': InventorySource(pk=8, id=8, name='AWS Source'),
'execution_environment': ExecutionEnvironment(id=4, name='Test EE'),
'instance_group': InstanceGroup(id=7, name='Test Instance Group'),
},
{
AutomationControllerJobScope.CLAIM_JOB_ID: 300,
AutomationControllerJobScope.CLAIM_JOB_NAME: 'AWS Sync',
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: 'scheduled',
AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME: 'Test Org',
AutomationControllerJobScope.CLAIM_ORGANIZATION_ID: 1,
AutomationControllerJobScope.CLAIM_INVENTORY_NAME: 'AWS Inventory',
AutomationControllerJobScope.CLAIM_INVENTORY_ID: 2,
AutomationControllerJobScope.CLAIM_UNIFIED_JOB_TEMPLATE_NAME: 'AWS Source',
AutomationControllerJobScope.CLAIM_UNIFIED_JOB_TEMPLATE_ID: 8,
AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_NAME: 'Test EE',
AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_ID: 4,
AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_NAME: 'Test Instance Group',
AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_ID: 7,
},
),
(
{
'id': 301,
'name': 'Minimal Inventory Update',
'launch_type': 'manual',
},
{
AutomationControllerJobScope.CLAIM_JOB_ID: 301,
AutomationControllerJobScope.CLAIM_JOB_NAME: 'Minimal Inventory Update',
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: 'manual',
},
),
],
)
def test_populate_claims_for_inventory_update(workload_attrs, expected_claims):
inventory_update = InventoryUpdate()
for attr, value in workload_attrs.items():
setattr(inventory_update, attr, value)
claims = jobs.populate_claims_for_workload(inventory_update)
assert claims == expected_claims
@pytest.mark.parametrize(
"workload_attrs,expected_claims",
[
(
{
'id': 400,
'name': 'Ping All Hosts',
'job_type': 'run',
'launch_type': 'manual',
'organization': Organization(id=1, name='Test Org'),
'inventory': Inventory(id=2, name='Test Inventory'),
'execution_environment': ExecutionEnvironment(id=4, name='Test EE'),
'instance_group': InstanceGroup(id=7, name='Test Instance Group'),
},
{
AutomationControllerJobScope.CLAIM_JOB_ID: 400,
AutomationControllerJobScope.CLAIM_JOB_NAME: 'Ping All Hosts',
AutomationControllerJobScope.CLAIM_JOB_TYPE: 'run',
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: 'manual',
AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME: 'Test Org',
AutomationControllerJobScope.CLAIM_ORGANIZATION_ID: 1,
AutomationControllerJobScope.CLAIM_INVENTORY_NAME: 'Test Inventory',
AutomationControllerJobScope.CLAIM_INVENTORY_ID: 2,
AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_NAME: 'Test EE',
AutomationControllerJobScope.CLAIM_EXECUTION_ENVIRONMENT_ID: 4,
AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_NAME: 'Test Instance Group',
AutomationControllerJobScope.CLAIM_INSTANCE_GROUP_ID: 7,
},
),
(
{
'id': 401,
'name': 'Minimal Ad Hoc',
'job_type': 'run',
'launch_type': 'manual',
},
{
AutomationControllerJobScope.CLAIM_JOB_ID: 401,
AutomationControllerJobScope.CLAIM_JOB_NAME: 'Minimal Ad Hoc',
AutomationControllerJobScope.CLAIM_JOB_TYPE: 'run',
AutomationControllerJobScope.CLAIM_LAUNCH_TYPE: 'manual',
},
),
],
)
def test_populate_claims_for_adhoc_command(workload_attrs, expected_claims):
adhoc_command = AdHocCommand()
for attr, value in workload_attrs.items():
setattr(adhoc_command, attr, value)
claims = jobs.populate_claims_for_workload(adhoc_command)
assert claims == expected_claims

View File

@@ -12,6 +12,10 @@ def pytest_sigterm():
pytest_sigterm.called_count += 1 pytest_sigterm.called_count += 1
def pytest_sigusr1():
pytest_sigusr1.called_count += 1
def tmp_signals_for_test(func): def tmp_signals_for_test(func):
""" """
When we run our internal signal handlers, it will call the original signal When we run our internal signal handlers, it will call the original signal
@@ -26,13 +30,17 @@ def tmp_signals_for_test(func):
def wrapper(): def wrapper():
original_sigterm = signal.getsignal(signal.SIGTERM) original_sigterm = signal.getsignal(signal.SIGTERM)
original_sigint = signal.getsignal(signal.SIGINT) original_sigint = signal.getsignal(signal.SIGINT)
original_sigusr1 = signal.getsignal(signal.SIGUSR1)
signal.signal(signal.SIGTERM, pytest_sigterm) signal.signal(signal.SIGTERM, pytest_sigterm)
signal.signal(signal.SIGINT, pytest_sigint) signal.signal(signal.SIGINT, pytest_sigint)
signal.signal(signal.SIGUSR1, pytest_sigusr1)
pytest_sigterm.called_count = 0 pytest_sigterm.called_count = 0
pytest_sigint.called_count = 0 pytest_sigint.called_count = 0
pytest_sigusr1.called_count = 0
func() func()
signal.signal(signal.SIGTERM, original_sigterm) signal.signal(signal.SIGTERM, original_sigterm)
signal.signal(signal.SIGINT, original_sigint) signal.signal(signal.SIGINT, original_sigint)
signal.signal(signal.SIGUSR1, original_sigusr1)
return wrapper return wrapper
@@ -58,11 +66,13 @@ def test_outer_inner_signal_handling():
assert signal_callback() is False assert signal_callback() is False
assert pytest_sigterm.called_count == 0 assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 0 assert pytest_sigint.called_count == 0
assert pytest_sigusr1.called_count == 0
f1() f1()
assert signal_callback() is False assert signal_callback() is False
assert signal.getsignal(signal.SIGTERM) is original_sigterm assert signal.getsignal(signal.SIGTERM) is original_sigterm
assert pytest_sigterm.called_count == 1 assert pytest_sigterm.called_count == 1
assert pytest_sigint.called_count == 0 assert pytest_sigint.called_count == 0
assert pytest_sigusr1.called_count == 0
@tmp_signals_for_test @tmp_signals_for_test
@@ -87,8 +97,31 @@ def test_inner_outer_signal_handling():
assert signal_callback() is False assert signal_callback() is False
assert pytest_sigterm.called_count == 0 assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 0 assert pytest_sigint.called_count == 0
assert pytest_sigusr1.called_count == 0
f1() f1()
assert signal_callback() is False assert signal_callback() is False
assert signal.getsignal(signal.SIGTERM) is original_sigterm assert signal.getsignal(signal.SIGTERM) is original_sigterm
assert pytest_sigterm.called_count == 0 assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 1 assert pytest_sigint.called_count == 1
assert pytest_sigusr1.called_count == 0
@tmp_signals_for_test
def test_sigusr1_signal_handling():
@with_signal_handling
def f1():
assert signal_callback() is False
signal_state.set_signal_flag(for_signal=signal.SIGUSR1)
assert signal_callback()
original_sigusr1 = signal.getsignal(signal.SIGUSR1)
assert signal_callback() is False
assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 0
assert pytest_sigusr1.called_count == 0
f1()
assert signal_callback() is False
assert signal.getsignal(signal.SIGUSR1) is original_sigusr1
assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 0
assert pytest_sigusr1.called_count == 1

View File

@@ -8,9 +8,7 @@ LOCAL_SETTINGS = (
'CACHES', 'CACHES',
'DEBUG', 'DEBUG',
'NAMED_URL_GRAPH', 'NAMED_URL_GRAPH',
'DISPATCHER_MOCK_PUBLISH',
# Platform flags are managed by the platform flags system and have environment-specific defaults # Platform flags are managed by the platform flags system and have environment-specific defaults
'FEATURE_DISPATCHERD_ENABLED',
'FEATURE_INDIRECT_NODE_COUNTING_ENABLED', 'FEATURE_INDIRECT_NODE_COUNTING_ENABLED',
) )
@@ -87,12 +85,9 @@ def test_development_defaults_feature_flags(monkeypatch):
spec.loader.exec_module(development_defaults) spec.loader.exec_module(development_defaults)
# Also import through the development settings to ensure both paths are tested # Also import through the development settings to ensure both paths are tested
from awx.settings.development import FEATURE_INDIRECT_NODE_COUNTING_ENABLED, FEATURE_DISPATCHERD_ENABLED from awx.settings.development import FEATURE_INDIRECT_NODE_COUNTING_ENABLED
# Verify the feature flags are set correctly in both the module and settings # Verify the feature flags are set correctly in both the module and settings
assert hasattr(development_defaults, 'FEATURE_INDIRECT_NODE_COUNTING_ENABLED') assert hasattr(development_defaults, 'FEATURE_INDIRECT_NODE_COUNTING_ENABLED')
assert development_defaults.FEATURE_INDIRECT_NODE_COUNTING_ENABLED is True assert development_defaults.FEATURE_INDIRECT_NODE_COUNTING_ENABLED is True
assert hasattr(development_defaults, 'FEATURE_DISPATCHERD_ENABLED')
assert development_defaults.FEATURE_DISPATCHERD_ENABLED is True
assert FEATURE_INDIRECT_NODE_COUNTING_ENABLED is True assert FEATURE_INDIRECT_NODE_COUNTING_ENABLED is True
assert FEATURE_DISPATCHERD_ENABLED is True

View File

@@ -132,6 +132,25 @@ def test_cert_with_key():
assert not pem_objects[1]['key_enc'] assert not pem_objects[1]['key_enc']
def test_ssh_key_with_whitespace():
# Test that SSH keys with leading/trailing whitespace/newlines are properly sanitized
# This addresses issue #14219 where copy-paste can introduce hidden newlines
valid_key_with_whitespace = "\n\n" + TEST_SSH_KEY_DATA + "\n\n"
pem_objects = validate_ssh_private_key(valid_key_with_whitespace)
assert pem_objects[0]['key_type'] == 'rsa'
assert not pem_objects[0]['key_enc']
# Test with just leading whitespace
valid_key_leading = "\n\n\n" + TEST_SSH_KEY_DATA
pem_objects = validate_ssh_private_key(valid_key_leading)
assert pem_objects[0]['key_type'] == 'rsa'
# Test with just trailing whitespace
valid_key_trailing = TEST_SSH_KEY_DATA + "\n\n\n"
pem_objects = validate_ssh_private_key(valid_key_trailing)
assert pem_objects[0]['key_type'] == 'rsa'
@pytest.mark.parametrize( @pytest.mark.parametrize(
"var_str", "var_str",
[ [

View File

@@ -0,0 +1,207 @@
# Copyright (c) 2024 Ansible, Inc.
# All Rights Reserved.
from unittest import mock
from awx.main.utils.proxy import get_first_remote_host_from_headers, is_proxy_in_headers
class TestGetFirstRemoteHostFromHeaders:
"""Tests for get_first_remote_host_from_headers function."""
def _make_mock_request(self, environ):
"""Create a mock request with the given environ dict."""
request = mock.MagicMock()
request.environ = environ
return request
def test_single_value_headers(self):
"""Test extraction from headers with single values (no commas)."""
request = self._make_mock_request(
{
"REMOTE_ADDR": "192.168.1.1",
"REMOTE_HOST": "client.example.com",
}
)
headers = ["REMOTE_ADDR", "REMOTE_HOST"]
result = get_first_remote_host_from_headers(request, headers)
assert result == {"192.168.1.1", "client.example.com"}
def test_comma_separated_only_first_entry(self):
"""Test that only the first entry is extracted from comma-separated values."""
request = self._make_mock_request(
{
"HTTP_X_FORWARDED_FOR": "10.0.0.1, 192.168.1.1, 172.16.0.1",
}
)
headers = ["HTTP_X_FORWARDED_FOR"]
result = get_first_remote_host_from_headers(request, headers)
# Only the first IP should be included
assert result == {"10.0.0.1"}
# Subsequent IPs should NOT be included
assert "192.168.1.1" not in result
assert "172.16.0.1" not in result
def test_comma_separated_with_whitespace(self):
"""Test that whitespace is properly stripped from first entry."""
request = self._make_mock_request(
{
"HTTP_X_FORWARDED_FOR": " 10.0.0.1 , 192.168.1.1",
}
)
headers = ["HTTP_X_FORWARDED_FOR"]
result = get_first_remote_host_from_headers(request, headers)
assert result == {"10.0.0.1"}
def test_multiple_headers_with_comma_separated(self):
"""Test multiple headers where some have comma-separated values."""
request = self._make_mock_request(
{
"HTTP_X_FORWARDED_FOR": "client.example.com, proxy1.example.com, proxy2.example.com",
"REMOTE_ADDR": "172.16.0.1",
"REMOTE_HOST": "proxy2.example.com",
}
)
headers = ["HTTP_X_FORWARDED_FOR", "REMOTE_ADDR", "REMOTE_HOST"]
result = get_first_remote_host_from_headers(request, headers)
# Should have first entry from X-Forwarded-For plus the single values from other headers
assert result == {"client.example.com", "172.16.0.1", "proxy2.example.com"}
# Should NOT have subsequent entries from X-Forwarded-For
assert "proxy1.example.com" not in result
def test_empty_header_value(self):
"""Test handling of empty header values."""
request = self._make_mock_request(
{
"HTTP_X_FORWARDED_FOR": "",
"REMOTE_ADDR": "192.168.1.1",
}
)
headers = ["HTTP_X_FORWARDED_FOR", "REMOTE_ADDR"]
result = get_first_remote_host_from_headers(request, headers)
assert result == {"192.168.1.1"}
def test_missing_header(self):
"""Test handling of headers that don't exist in environ."""
request = self._make_mock_request(
{
"REMOTE_ADDR": "192.168.1.1",
}
)
headers = ["HTTP_X_FORWARDED_FOR", "REMOTE_ADDR", "REMOTE_HOST"]
result = get_first_remote_host_from_headers(request, headers)
assert result == {"192.168.1.1"}
def test_empty_headers_list(self):
"""Test with no headers specified."""
request = self._make_mock_request(
{
"REMOTE_ADDR": "192.168.1.1",
}
)
headers = []
result = get_first_remote_host_from_headers(request, headers)
assert result == set()
def test_whitespace_only_first_entry(self):
"""Test handling when first entry is whitespace only."""
request = self._make_mock_request(
{
"HTTP_X_FORWARDED_FOR": " , 192.168.1.1",
}
)
headers = ["HTTP_X_FORWARDED_FOR"]
result = get_first_remote_host_from_headers(request, headers)
# Empty/whitespace first entry should be skipped
assert result == set()
def test_single_entry_with_trailing_comma(self):
"""Test single entry that happens to have a trailing comma."""
request = self._make_mock_request(
{
"HTTP_X_FORWARDED_FOR": "10.0.0.1,",
}
)
headers = ["HTTP_X_FORWARDED_FOR"]
result = get_first_remote_host_from_headers(request, headers)
assert result == {"10.0.0.1"}
class TestIsProxyInHeaders:
"""Tests for is_proxy_in_headers function."""
def _make_mock_request(self, environ):
"""Create a mock request with the given environ dict."""
request = mock.MagicMock()
request.environ = environ
return request
def test_proxy_found_in_single_value(self):
"""Test proxy detection in single-value header."""
request = self._make_mock_request(
{
"REMOTE_ADDR": "192.168.1.1",
}
)
result = is_proxy_in_headers(request, ["192.168.1.1"], ["REMOTE_ADDR"])
assert result is True
def test_proxy_found_in_comma_separated(self):
"""Test proxy detection in comma-separated header value."""
request = self._make_mock_request(
{
"HTTP_X_FORWARDED_FOR": "10.0.0.1, 192.168.1.1, 172.16.0.1",
}
)
result = is_proxy_in_headers(request, ["192.168.1.1"], ["HTTP_X_FORWARDED_FOR"])
assert result is True
def test_proxy_not_found(self):
"""Test when proxy is not in any header."""
request = self._make_mock_request(
{
"REMOTE_ADDR": "10.0.0.1",
}
)
result = is_proxy_in_headers(request, ["192.168.1.1"], ["REMOTE_ADDR"])
assert result is False
def test_multiple_proxies_one_match(self):
"""Test with multiple allowed proxies, one matches."""
request = self._make_mock_request(
{
"REMOTE_HOST": "proxy.example.com",
}
)
result = is_proxy_in_headers(
request,
["proxy1.example.com", "proxy.example.com", "proxy2.example.com"],
["REMOTE_HOST"],
)
assert result is True

View File

@@ -43,6 +43,9 @@ from django.apps import apps
# AWX # AWX
from awx.conf.license import get_license from awx.conf.license import get_license
# ansible-runner
from ansible_runner.utils.capacity import get_mem_in_bytes, get_cpu_count
logger = logging.getLogger('awx.main.utils') logger = logging.getLogger('awx.main.utils')
__all__ = [ __all__ = [
@@ -997,9 +1000,15 @@ def getattrd(obj, name, default=NoDefaultProvided):
raise raise
def getattr_dne(obj, name, notfound=ObjectDoesNotExist): empty = object()
def getattr_dne(obj, name, default=empty, notfound=ObjectDoesNotExist):
try: try:
return getattr(obj, name) if default is empty:
return getattr(obj, name)
else:
return getattr(obj, name, default)
except notfound: except notfound:
return None return None
@@ -1220,3 +1229,38 @@ def unified_job_class_to_event_table_name(job_class):
def load_all_entry_points_for(entry_point_subsections: list[str], /) -> dict[str, EntryPoint]: def load_all_entry_points_for(entry_point_subsections: list[str], /) -> dict[str, EntryPoint]:
return {ep.name: ep for entry_point_category in entry_point_subsections for ep in entry_points(group=f'awx_plugins.{entry_point_category}')} return {ep.name: ep for entry_point_category in entry_point_subsections for ep in entry_points(group=f'awx_plugins.{entry_point_category}')}
def get_auto_max_workers():
"""Method we normally rely on to get max_workers
Uses almost same logic as Instance.local_health_check
The important thing is to be MORE than Instance.capacity
so that the task-manager does not over-schedule this node
Ideally we would just use the capacity from the database plus reserve workers,
but this poses some bootstrap problems where OCP task containers
register themselves after startup
"""
# Get memory from ansible-runner
total_memory_gb = get_mem_in_bytes()
# This may replace memory calculation with a user override
corrected_memory = get_corrected_memory(total_memory_gb)
# Get same number as max forks based on memory, this function takes memory as bytes
mem_capacity = get_mem_effective_capacity(corrected_memory, is_control_node=True)
# Follow same process for CPU capacity constraint
cpu_count = get_cpu_count()
corrected_cpu = get_corrected_cpu(cpu_count)
cpu_capacity = get_cpu_effective_capacity(corrected_cpu, is_control_node=True)
# Here is what is different from health checks,
auto_max = max(mem_capacity, cpu_capacity)
# add magic number of extra workers to ensure
# we have a few extra workers to run the heartbeat
auto_max += 7
return auto_max

View File

@@ -4,9 +4,9 @@ import tempfile
import urllib.parse as urlparse import urllib.parse as urlparse
from django.conf import settings from django.conf import settings
from dispatcherd.publish import task
from awx.main.utils.reload import supervisor_service_command from awx.main.utils.reload import supervisor_service_command
from awx.main.dispatch.publish import task as task_awx
def construct_rsyslog_conf_template(settings=settings): def construct_rsyslog_conf_template(settings=settings):
@@ -139,7 +139,7 @@ def construct_rsyslog_conf_template(settings=settings):
return tmpl return tmpl
@task_awx(queue='rsyslog_configurer', timeout=600, on_duplicate='queue_one') @task(queue='rsyslog_configurer', timeout=600, on_duplicate='queue_one')
def reconfigure_rsyslog(): def reconfigure_rsyslog():
tmpl = construct_rsyslog_conf_template() tmpl = construct_rsyslog_conf_template()
# Write config to a temp file then move it to preserve atomicity # Write config to a temp file then move it to preserve atomicity

View File

@@ -257,8 +257,7 @@ class LogstashFormatter(LogstashFormatterBase):
return fields return fields
def format(self, record): def format(self, record):
stamp = datetime.utcfromtimestamp(record.created) stamp = datetime.fromtimestamp(record.created, tz=tzutc())
stamp = stamp.replace(tzinfo=tzutc())
message = { message = {
# Field not included, but exist in related logs # Field not included, but exist in related logs
# 'path': record.pathname # 'path': record.pathname

View File

@@ -45,3 +45,38 @@ def delete_headers_starting_with_http(request: Request, headers: list[str]):
for header in headers: for header in headers:
if header.startswith('HTTP_'): if header.startswith('HTTP_'):
request.environ.pop(header, None) request.environ.pop(header, None)
def get_first_remote_host_from_headers(request: Request, headers: list[str]) -> set[str]:
"""
Extract remote host addresses from headers, considering only the first entry
in comma-separated values.
For headers like X-Forwarded-For that may contain multiple IPs (e.g., "client, proxy1, proxy2"),
only the first entry (the original client) is considered.
Example:
request.environ = {
"HTTP_X_FORWARDED_FOR": "10.0.0.1, 192.168.1.1, 172.16.0.1",
"REMOTE_ADDR": "192.168.1.1",
"REMOTE_HOST": "proxy.example.com"
}
headers = ["HTTP_X_FORWARDED_FOR", "REMOTE_ADDR", "REMOTE_HOST"]
Returns: {"10.0.0.1", "192.168.1.1", "proxy.example.com"}
(Only the first IP "10.0.0.1" from X-Forwarded-For, not the full chain)
request: The DRF/Django request. request.environ dict will be used for extracting hosts
headers: A list of header keys to check for remote host values
"""
remote_hosts = set()
for header in headers:
header_value = request.environ.get(header, '')
if header_value:
# Only take the first entry if comma-separated
first_value = header_value.split(',')[0].strip()
if first_value:
remote_hosts.add(first_value)
return remote_hosts

View File

@@ -181,6 +181,8 @@ def validate_ssh_private_key(data):
certificates; should handle any valid options for ssh_private_key on a certificates; should handle any valid options for ssh_private_key on a
credential. credential.
""" """
# Strip leading and trailing whitespace/newlines to handle common copy-paste issues
data = data.strip()
return validate_pem(data, min_keys=1) return validate_pem(data, min_keys=1)

View File

@@ -94,7 +94,7 @@ class WebsocketRelayConnection:
except asyncio.CancelledError: except asyncio.CancelledError:
# TODO: Check if connected and disconnect # TODO: Check if connected and disconnect
# Possibly use run_until_complete() if disconnect is async # Possibly use run_until_complete() if disconnect is async
logger.warning(f"Connection from {self.name} to {self.remote_host} cancelled.") logger.warning(f"Connection from {self.name} to {self.remote_host} canceled.")
except client_exceptions.ClientConnectorError as e: except client_exceptions.ClientConnectorError as e:
logger.warning(f"Connection from {self.name} to {self.remote_host} failed: '{e}'.", exc_info=True) logger.warning(f"Connection from {self.name} to {self.remote_host} failed: '{e}'.", exc_info=True)
except asyncio.TimeoutError: except asyncio.TimeoutError:
@@ -291,7 +291,7 @@ class WebSocketRelayManager(object):
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning(f"Tried to cancel relay connection for {hostname} but it timed out during cleanup.") logger.warning(f"Tried to cancel relay connection for {hostname} but it timed out during cleanup.")
except asyncio.CancelledError: except asyncio.CancelledError:
# Handle the case where the task was already cancelled by the time we got here. # Handle the case where the task was already canceled by the time we got here.
pass pass
del self.relay_connections[hostname] del self.relay_connections[hostname]

View File

@@ -83,7 +83,6 @@ class ActionModule(ActionBase):
password = self._task.args.get('password', None) password = self._task.args.get('password', None)
client_id = self._task.args.get('client_id', None) client_id = self._task.args.get('client_id', None)
client_secret = self._task.args.get('client_secret', None) client_secret = self._task.args.get('client_secret', None)
oidc_endpoint = self._task.args.get('oidc_endpoint', DEFAULT_OIDC_ENDPOINT)
session.headers.update( session.headers.update(
{ {
@@ -93,7 +92,7 @@ class ActionModule(ActionBase):
) )
if authentication == 'service_account' or (client_id and client_secret): if authentication == 'service_account' or (client_id and client_secret):
data = self._obtain_auth_token(oidc_endpoint, client_id, client_secret) data = self._obtain_auth_token(DEFAULT_OIDC_ENDPOINT, client_id, client_secret)
if 'token' not in data: if 'token' not in data:
result['failed'] = data['failed'] result['failed'] = data['failed']
result['msg'] = data['msg'] result['msg'] = data['msg']

View File

@@ -7,7 +7,6 @@ import os
import re # noqa import re # noqa
import tempfile import tempfile
import socket import socket
from datetime import timedelta
DEBUG = True DEBUG = True
SQL_DEBUG = DEBUG SQL_DEBUG = DEBUG
@@ -416,49 +415,34 @@ EXECUTION_NODE_REMEDIATION_CHECKS = 60 * 30 # once every 30 minutes check if an
# Amount of time dispatcher will try to reconnect to database for jobs and consuming new work # Amount of time dispatcher will try to reconnect to database for jobs and consuming new work
DISPATCHER_DB_DOWNTIME_TOLERANCE = 40 DISPATCHER_DB_DOWNTIME_TOLERANCE = 40
# If you set this, nothing will ever be sent to pg_notify
# this is not practical to use, although periodic schedules may still run slugish but functional tasks
# sqlite3 based tests will use this
DISPATCHER_MOCK_PUBLISH = False
BROKER_URL = 'unix:///var/run/redis/redis.sock' BROKER_URL = 'unix:///var/run/redis/redis.sock'
REDIS_RETRY_COUNT = 3 # Number of retries for Redis connection errors REDIS_RETRY_COUNT = 3 # Number of retries for Redis connection errors
REDIS_BACKOFF_CAP = 1.0 # Maximum backoff delay in seconds for Redis retries REDIS_BACKOFF_CAP = 1.0 # Maximum backoff delay in seconds for Redis retries
REDIS_BACKOFF_BASE = 0.5 # Base for exponential backoff calculation for Redis retries REDIS_BACKOFF_BASE = 0.5 # Base for exponential backoff calculation for Redis retries
CELERYBEAT_SCHEDULE = {
'tower_scheduler': {'task': 'awx.main.tasks.system.awx_periodic_scheduler', 'schedule': timedelta(seconds=30), 'options': {'expires': 20}}, DISPATCHER_SCHEDULE = {
'cluster_heartbeat': { 'awx.main.tasks.system.awx_periodic_scheduler': {'task': 'awx.main.tasks.system.awx_periodic_scheduler', 'schedule': 30, 'options': {'expires': 20}},
'awx.main.tasks.system.cluster_node_heartbeat': {
'task': 'awx.main.tasks.system.cluster_node_heartbeat', 'task': 'awx.main.tasks.system.cluster_node_heartbeat',
'schedule': timedelta(seconds=CLUSTER_NODE_HEARTBEAT_PERIOD), 'schedule': CLUSTER_NODE_HEARTBEAT_PERIOD,
'options': {'expires': 50}, 'options': {'expires': 50},
}, },
'gather_analytics': {'task': 'awx.main.tasks.system.gather_analytics', 'schedule': timedelta(minutes=5)}, 'awx.main.tasks.system.gather_analytics': {'task': 'awx.main.tasks.system.gather_analytics', 'schedule': 300},
'task_manager': {'task': 'awx.main.scheduler.tasks.task_manager', 'schedule': timedelta(seconds=20), 'options': {'expires': 20}}, 'awx.main.scheduler.tasks.task_manager': {'task': 'awx.main.scheduler.tasks.task_manager', 'schedule': 20, 'options': {'expires': 20}},
'dependency_manager': {'task': 'awx.main.scheduler.tasks.dependency_manager', 'schedule': timedelta(seconds=20), 'options': {'expires': 20}}, 'awx.main.scheduler.tasks.dependency_manager': {'task': 'awx.main.scheduler.tasks.dependency_manager', 'schedule': 20, 'options': {'expires': 20}},
'k8s_reaper': {'task': 'awx.main.tasks.system.awx_k8s_reaper', 'schedule': timedelta(seconds=60), 'options': {'expires': 50}}, 'awx.main.tasks.system.awx_k8s_reaper': {'task': 'awx.main.tasks.system.awx_k8s_reaper', 'schedule': 60, 'options': {'expires': 50}},
'receptor_reaper': {'task': 'awx.main.tasks.system.awx_receptor_workunit_reaper', 'schedule': timedelta(seconds=60)}, 'awx.main.tasks.system.awx_receptor_workunit_reaper': {'task': 'awx.main.tasks.system.awx_receptor_workunit_reaper', 'schedule': 60},
'send_subsystem_metrics': {'task': 'awx.main.analytics.analytics_tasks.send_subsystem_metrics', 'schedule': timedelta(seconds=20)}, 'awx.main.analytics.analytics_tasks.send_subsystem_metrics': {'task': 'awx.main.analytics.analytics_tasks.send_subsystem_metrics', 'schedule': 20},
'cleanup_images': {'task': 'awx.main.tasks.system.cleanup_images_and_files', 'schedule': timedelta(hours=3)}, 'awx.main.tasks.system.cleanup_images_and_files': {'task': 'awx.main.tasks.system.cleanup_images_and_files', 'schedule': 10800},
'cleanup_host_metrics': {'task': 'awx.main.tasks.host_metrics.cleanup_host_metrics', 'schedule': timedelta(hours=3, minutes=30)}, 'awx.main.tasks.host_metrics.cleanup_host_metrics': {'task': 'awx.main.tasks.host_metrics.cleanup_host_metrics', 'schedule': 12600},
'host_metric_summary_monthly': {'task': 'awx.main.tasks.host_metrics.host_metric_summary_monthly', 'schedule': timedelta(hours=4)}, 'awx.main.tasks.host_metrics.host_metric_summary_monthly': {'task': 'awx.main.tasks.host_metrics.host_metric_summary_monthly', 'schedule': 14400},
'periodic_resource_sync': {'task': 'awx.main.tasks.system.periodic_resource_sync', 'schedule': timedelta(minutes=15)}, 'awx.main.tasks.system.periodic_resource_sync': {'task': 'awx.main.tasks.system.periodic_resource_sync', 'schedule': 900},
'cleanup_and_save_indirect_host_entries_fallback': { 'awx.main.tasks.host_indirect.cleanup_and_save_indirect_host_entries_fallback': {
'task': 'awx.main.tasks.host_indirect.cleanup_and_save_indirect_host_entries_fallback', 'task': 'awx.main.tasks.host_indirect.cleanup_and_save_indirect_host_entries_fallback',
'schedule': timedelta(minutes=60), 'schedule': 3600,
}, },
} }
DISPATCHER_SCHEDULE = {}
for options in CELERYBEAT_SCHEDULE.values():
new_options = options.copy()
task_name = options['task']
# Handle the only one exception case of the heartbeat which has a new implementation
if task_name == 'awx.main.tasks.system.cluster_node_heartbeat':
task_name = 'awx.main.tasks.system.adispatch_cluster_node_heartbeat'
new_options['task'] = task_name
new_options['schedule'] = options['schedule'].total_seconds()
DISPATCHER_SCHEDULE[task_name] = new_options
# Django Caching Configuration # Django Caching Configuration
DJANGO_REDIS_IGNORE_EXCEPTIONS = True DJANGO_REDIS_IGNORE_EXCEPTIONS = True
CACHES = {'default': {'BACKEND': 'awx.main.cache.AWXRedisCache', 'LOCATION': 'unix:///var/run/redis/redis.sock?db=1'}} CACHES = {'default': {'BACKEND': 'awx.main.cache.AWXRedisCache', 'LOCATION': 'unix:///var/run/redis/redis.sock?db=1'}}
@@ -716,7 +700,6 @@ DISABLE_LOCAL_AUTH = False
TOWER_URL_BASE = "https://platformhost" TOWER_URL_BASE = "https://platformhost"
INSIGHTS_URL_BASE = "https://example.org" INSIGHTS_URL_BASE = "https://example.org"
INSIGHTS_OIDC_ENDPOINT = "https://sso.example.org/"
INSIGHTS_AGENT_MIME = 'application/example' INSIGHTS_AGENT_MIME = 'application/example'
# See https://github.com/ansible/awx-facts-playbooks # See https://github.com/ansible/awx-facts-playbooks
INSIGHTS_SYSTEM_ID_FILE = '/etc/redhat-access-insights/machine-id' INSIGHTS_SYSTEM_ID_FILE = '/etc/redhat-access-insights/machine-id'
@@ -1043,12 +1026,14 @@ SPECTACULAR_SETTINGS = {
'SCHEMA_PATH_PREFIX': r'/api/v[0-9]', 'SCHEMA_PATH_PREFIX': r'/api/v[0-9]',
'DEFAULT_GENERATOR_CLASS': 'drf_spectacular.generators.SchemaGenerator', 'DEFAULT_GENERATOR_CLASS': 'drf_spectacular.generators.SchemaGenerator',
'SCHEMA_COERCE_PATH_PK_SUFFIX': True, 'SCHEMA_COERCE_PATH_PK_SUFFIX': True,
'CONTACT': {'email': 'controller-eng@redhat.com'}, 'CONTACT': {'email': 'ansible-community@redhat.com'},
'LICENSE': {'name': 'Apache License'}, 'LICENSE': {'name': 'Apache License'},
'TERMS_OF_SERVICE': 'https://www.google.com/policies/terms/', 'TERMS_OF_SERVICE': 'https://www.google.com/policies/terms/',
# Use our custom schema class that handles swagger_topic and deprecated views # Use our custom schema class that handles swagger_topic and deprecated views
'DEFAULT_SCHEMA_CLASS': 'awx.api.schema.CustomAutoSchema', 'DEFAULT_SCHEMA_CLASS': 'awx.api.schema.CustomAutoSchema',
'COMPONENT_SPLIT_REQUEST': True, 'COMPONENT_SPLIT_REQUEST': True,
# Postprocessing hook to filter CredentialType enum values
'POSTPROCESSING_HOOKS': ['awx.api.schema.filter_credential_type_schema'],
'SWAGGER_UI_SETTINGS': { 'SWAGGER_UI_SETTINGS': {
'deepLinking': True, 'deepLinking': True,
'persistAuthorization': True, 'persistAuthorization': True,
@@ -1149,7 +1134,6 @@ OPA_REQUEST_RETRIES = 2 # The number of retry attempts for connecting to the OP
# feature flags # feature flags
FEATURE_INDIRECT_NODE_COUNTING_ENABLED = False FEATURE_INDIRECT_NODE_COUNTING_ENABLED = False
FEATURE_DISPATCHERD_ENABLED = False
# Dispatcher worker lifetime. If set to None, workers will never be retired # Dispatcher worker lifetime. If set to None, workers will never be retired
# based on age. Note workers will finish their last task before retiring if # based on age. Note workers will finish their last task before retiring if

View File

@@ -69,4 +69,3 @@ AWX_DISABLE_TASK_MANAGERS = False
# ======================!!!!!!! FOR DEVELOPMENT ONLY !!!!!!!================================= # ======================!!!!!!! FOR DEVELOPMENT ONLY !!!!!!!=================================
FEATURE_INDIRECT_NODE_COUNTING_ENABLED = True FEATURE_INDIRECT_NODE_COUNTING_ENABLED = True
FEATURE_DISPATCHERD_ENABLED = True

View File

@@ -14,7 +14,7 @@ $(function() {
$('span.str').each(function() { $('span.str').each(function() {
var s = $(this).html(); var s = $(this).html();
if (s.match(/^\"\/.+\/\"$/) || s.match(/^\"\/.+\/\?.*\"$/)) { if (s.match(/^\"\/.+\/\"$/) || s.match(/^\"\/.+\/\?.*\"$/)) {
$(this).html('"<a href=' + s + '>' + s.replace(/\"/g, '') + '</a>"'); $(this).html('"<a href=' + s + '>' + s.replaceAll('"', '') + '</a>"');
} }
}); });
@@ -27,7 +27,7 @@ $(function() {
}).each(function() { }).each(function() {
$(this).nextUntil('span.pun:contains("]")').filter('span.str').each(function() { $(this).nextUntil('span.pun:contains("]")').filter('span.str').each(function() {
if ($(this).text().match(/^\".+\"$/)) { if ($(this).text().match(/^\".+\"$/)) {
var s = $(this).text().replace(/\"/g, ''); var s = $(this).text().replaceAll('"', '');
$(this).html('"<a href="' + '?host=' + s + '">' + s + '</a>"'); $(this).html('"<a href="' + '?host=' + s + '">' + s + '</a>"');
} }
else if ($(this).text() !== '"') { else if ($(this).text() !== '"') {

View File

@@ -1,4 +1,5 @@
<html> <!DOCTYPE html>
<html lang="en">
<head> <head>
<title>Redirecting</title> <title>Redirecting</title>
<meta http-equiv="refresh" content="0;URL='/#'"/> <meta http-equiv="refresh" content="0;URL='/#'"/>

View File

@@ -1,5 +1,5 @@
<!DOCTYPE html> <!DOCTYPE html>
<html> <html lang="en">
<head> <head>
<title>On Break...</title> <title>On Break...</title>
<meta http-equiv="refresh" content="2"> <meta http-equiv="refresh" content="2">
@@ -8,7 +8,7 @@
<body> <body>
<div class="container"> <div class="container">
<div class="upper_div"> <div class="upper_div">
<img class="main_image" src="/static/awx-spud-reading.svg"/> <img class="main_image" src="/static/awx-spud-reading.svg" alt="AWX mascot reading a book"/>
<span class="error_number">502</span> <span class="error_number">502</span>
</div> </div>
<div class="message_div"> <div class="message_div">

View File

@@ -1,5 +1,5 @@
<!DOCTYPE html> <!DOCTYPE html>
<html> <html lang="en">
<head> <head>
<title>On Break...</title> <title>On Break...</title>
<meta http-equiv="refresh" content="2"> <meta http-equiv="refresh" content="2">
@@ -8,7 +8,7 @@
<body> <body>
<div class="container"> <div class="container">
<div class="upper_div"> <div class="upper_div">
<img class="main_image" src="/static/awx-spud-reading.svg"/> <img class="main_image" src="/static/awx-spud-reading.svg" alt="AWX mascot reading a book"/>
<span class="error_number">504</span> <span class="error_number">504</span>
</div> </div>
<div class="message_div"> <div class="message_div">

View File

@@ -28,7 +28,6 @@ body {
.upper_div { .upper_div {
background-color: #F8EBA7; background-color: #F8EBA7;
justify-content: center; justify-content: center;
align-items: center;
text-align: center; text-align: center;
height: 50%; height: 50%;
align-items: flex-end; align-items: flex-end;
@@ -48,7 +47,7 @@ body {
right: 90px; right: 90px;
font-size:200px; font-size:200px;
color: #FDBA48; color: #FDBA48;
font-family: Impact, Haettenschweiler, "Franklin Gothic Bold", Charcoal, "Helvetica Inserat", "Bitstream Vera Sans Bold", "Arial Black", "sans serif"; font-family: Impact, Haettenschweiler, "Franklin Gothic Bold", Charcoal, "Helvetica Inserat", "Bitstream Vera Sans Bold", "Arial Black", sans-serif;
} }
.message_div { .message_div {
@@ -62,7 +61,7 @@ body {
.m1,.m2,.m3 { .m1,.m2,.m3 {
color: #151515; color: #151515;
width: 100%; width: 100%;
font-family: redhat-display-medium; font-family: redhat-display-medium, sans-serif;
} }
.m1 { .m1 {
@@ -78,5 +77,5 @@ body {
.m3 { .m3 {
font-size: 16px; font-size: 16px;
padding-top: 20px; padding-top: 20px;
font-family: redhat-display-regular; font-family: redhat-display-regular, sans-serif;
} }

View File

@@ -18,7 +18,7 @@ div.response-info span.meta {
<div class="container"> <div class="container">
<div class="navbar-header"> <div class="navbar-header">
<a class="navbar-brand" href="/"> <a class="navbar-brand" href="/">
<img class="logo" src="{% static 'media/logo-header.svg' %}"> <img class="logo" src="{% static 'media/logo-header.svg' %}" alt="AWX">
</a> </a>
<a class="navbar-title" href="{{ request.get_full_path }}"> <a class="navbar-title" href="{{ request.get_full_path }}">
<span>&nbsp;&mdash; {{name}}</span> <span>&nbsp;&mdash; {{name}}</span>

View File

@@ -9,7 +9,7 @@
<div class="well" style="width: 320px; margin-left: auto; margin-right: auto"> <div class="well" style="width: 320px; margin-left: auto; margin-right: auto">
<div class="row-fluid"> <div class="row-fluid">
<form action="{% url 'api:login' %}" role="form" method="post"> <form action="{% url 'api:login' %}" method="post">
{% csrf_token %} {% csrf_token %}
<input type="hidden" name="next" value={% if request.GET.next %}"{{ request.GET.next }}"{% elif request.POST.next %}"{{ request.POST.next }}"{% else %}"{% url 'api:api_root_view' %}"{% endif %} /> <input type="hidden" name="next" value={% if request.GET.next %}"{{ request.GET.next }}"{% elif request.POST.next %}"{{ request.POST.next }}"{% else %}"{% url 'api:api_root_view' %}"{% endif %} />
<div class="clearfix control-group {% if form.username.errors %}error{% endif %}" <div class="clearfix control-group {% if form.username.errors %}error{% endif %}"

View File

@@ -12,6 +12,15 @@ class ConnectionException(exc.Common):
pass pass
class TokenAuth(requests.auth.AuthBase):
def __init__(self, token):
self.token = token
def __call__(self, request):
request.headers['Authorization'] = 'Bearer {0.token}'.format(self)
return request
def log_elapsed(r, *args, **kwargs): # requests hook to display API elapsed time def log_elapsed(r, *args, **kwargs): # requests hook to display API elapsed time
log.debug('"{0.request.method} {0.url}" elapsed: {0.elapsed}'.format(r)) log.debug('"{0.request.method} {0.url}" elapsed: {0.elapsed}'.format(r))
@@ -37,7 +46,7 @@ class Connection(object):
self.get(config.api_base_path) # this causes a cookie w/ the CSRF token to be set self.get(config.api_base_path) # this causes a cookie w/ the CSRF token to be set
return dict(next=next) return dict(next=next)
def login(self, username=None, password=None, **kwargs): def login(self, username=None, password=None, token=None, **kwargs):
if username and password: if username and password:
_next = kwargs.get('next') _next = kwargs.get('next')
if _next: if _next:
@@ -52,6 +61,8 @@ class Connection(object):
self.uses_session_cookie = True self.uses_session_cookie = True
else: else:
self.session.auth = (username, password) self.session.auth = (username, password)
elif token:
self.session.auth = TokenAuth(token)
else: else:
self.session.auth = None self.session.auth = None

View File

@@ -83,12 +83,23 @@ class CLI(object):
def authenticate(self): def authenticate(self):
"""Configure the current session for authentication. """Configure the current session for authentication.
Uses Basic authentication when AWXKIT_FORCE_BASIC_AUTH environment variable Authentication priority:
is set to true, otherwise defaults to session-based authentication. 1. Token authentication (if --conf.token provided)
2. Basic authentication (if AWXKIT_FORCE_BASIC_AUTH=true)
3. Session-based authentication (default)
For AAP Gateway environments, set AWXKIT_FORCE_BASIC_AUTH=true to bypass For AAP Gateway environments, set AWXKIT_FORCE_BASIC_AUTH=true to bypass
session login restrictions. session login restrictions when using username/password.
""" """
# Token authentication (if token is provided)
token = self.get_config('token')
if token:
config.use_sessions = False
self.root.connection.login(None, None, token=token)
return
# Check if Basic auth is forced via environment variable # Check if Basic auth is forced via environment variable
if config.get('force_basic_auth', False): if config.get('force_basic_auth', False):
config.use_sessions = False config.use_sessions = False
@@ -251,7 +262,13 @@ class CLI(object):
if self.resource != 'settings': if self.resource != 'settings':
for method in ('list', 'modify', 'create'): for method in ('list', 'modify', 'create'):
if method in parser.parser.choices: if method in parser.parser.choices:
parser.build_query_arguments(method, 'GET' if method == 'list' else 'POST') if method == 'list':
http_method = 'GET'
elif method == 'modify' and 'PUT' in parser.options:
http_method = 'PUT'
else:
http_method = 'POST'
parser.build_query_arguments(method, http_method)
if from_sphinx: if from_sphinx:
parsed, extra = self.parser.parse_known_args(self.argv) parsed, extra = self.parser.parse_known_args(self.argv)
else: else:

View File

@@ -59,6 +59,12 @@ def add_authentication_arguments(parser, env):
default=env.get('CONTROLLER_PASSWORD', env.get('TOWER_PASSWORD', config_password)), default=env.get('CONTROLLER_PASSWORD', env.get('TOWER_PASSWORD', config_password)),
metavar='TEXT', metavar='TEXT',
) )
auth.add_argument(
'--conf.token',
default=env.get('CONTROLLER_OAUTH_TOKEN', env.get('TOWER_OAUTH_TOKEN', None)),
metavar='TEXT',
help='OAuth2 token for authentication (takes precedence over username/password)',
)
auth.add_argument( auth.add_argument(
'-k', '-k',

View File

@@ -102,6 +102,18 @@ class ResourceOptionsParser(object):
if '299' in warning and 'deprecated' in warning: if '299' in warning and 'deprecated' in warning:
self.deprecated = True self.deprecated = True
self.allowed_options = options.headers.get('Allow', '').split(', ') self.allowed_options = options.headers.get('Allow', '').split(', ')
# If the user can PUT on the detail endpoint but doesn't have
# POST on the list endpoint, use the detail endpoint's
# action schema so that 'modify' fields are populated.
if 'POST' not in self.options and 'PUT' in self.allowed_options:
try:
detail_actions = options.json().get('actions', {})
except Exception:
detail_actions = {}
if 'PUT' in detail_actions:
self.options['PUT'] = detail_actions['PUT']
elif 'GET' in detail_actions:
self.options['PUT'] = detail_actions['GET']
def build_list_actions(self): def build_list_actions(self):
action_map = { action_map = {
@@ -109,6 +121,10 @@ class ResourceOptionsParser(object):
'POST': 'create', 'POST': 'create',
} }
for method, action in self.options.items(): for method, action in self.options.items():
# Skip 'PUT', which may be added by get_allowed_options
# and is handled separately by build_detail_actions
if method not in action_map:
continue
method = action_map[method] method = action_map[method]
parser = self.parser.add_parser(method, help='') parser = self.parser.add_parser(method, help='')
if method == 'list': if method == 'list':

View File

@@ -44,6 +44,48 @@ def setup_session_auth(cli_args: Optional[List[str]] = None) -> Tuple[CLI, Mock,
return cli, mock_root, mock_load_session return cli, mock_root, mock_load_session
def setup_token_auth(cli_args: Optional[List[str]] = None) -> Tuple[CLI, Mock, Mock]:
"""Set up CLI with mocked connection for Token auth testing"""
cli = CLI()
cli.parse_args(cli_args or ['awx', '--conf.token', 'test-token-abc123'])
mock_root = Mock()
mock_connection = Mock()
mock_root.connection = mock_connection
cli.root = mock_root
return cli, mock_root, mock_connection
def test_token_auth_preserved(monkeypatch):
"""
REGRESSION TEST: Token authentication must still work (existed in 4.6.12)
This test documents the customer's working scenario from 4.6.12:
awx login --conf.host URL --conf.username USER --conf.password PASS
# Returns: {"token": "E*******J"}
awx --conf.host URL --conf.token E*******J job_templates launch ...
# This WORKED in 4.6.12
BREAKING CHANGE: Version 4.6.21 removed token authentication entirely,
causing customer to report: "neither token no username/password are working"
This test will FAIL with current code and PASS once fixed.
"""
cli, mock_root, mock_connection = setup_token_auth(['awx', '--conf.host', 'https://aap-sbx.testbank.com', '--conf.token', 'E1234567890J'])
monkeypatch.setattr(config, 'force_basic_auth', False)
# Execute authentication
cli.authenticate()
# Token auth should call login with token parameter
mock_connection.login.assert_called_once_with(None, None, token='E1234567890J')
# Should NOT use sessions when token is provided
assert not config.use_sessions
def test_basic_auth_enabled(monkeypatch): def test_basic_auth_enabled(monkeypatch):
"""Test that AWXKIT_FORCE_BASIC_AUTH=true enables Basic authentication""" """Test that AWXKIT_FORCE_BASIC_AUTH=true enables Basic authentication"""
cli, mock_root, mock_connection = setup_basic_auth() cli, mock_root, mock_connection = setup_basic_auth()

Some files were not shown because too many files have changed in this diff Show More