mirror of
https://github.com/ansible/awx.git
synced 2026-04-27 12:45:24 -02:30
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ed1b5c5519 | ||
|
|
d2e51c4124 |
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -24,7 +24,7 @@ in as the first entry for your PR title.
|
||||
|
||||
|
||||
|
||||
##### STEPS TO REPRODUCE AND EXTRA INFO
|
||||
##### ADDITIONAL INFORMATION
|
||||
<!---
|
||||
Include additional information to help people understand the change here.
|
||||
For bugs that don't have a linked bug report, a step-by-step reproduction
|
||||
|
||||
55
.github/workflows/_repo-owns-branch.yml
vendored
55
.github/workflows/_repo-owns-branch.yml
vendored
@@ -1,55 +0,0 @@
|
||||
---
|
||||
name: Repo Owns Branch
|
||||
|
||||
# Reusable workflow that determines whether the current repository
|
||||
# owns the current branch for push operations.
|
||||
#
|
||||
# Ownership rules:
|
||||
# - ansible/awx owns: devel, feature_*
|
||||
# - ansible/tower owns: stable-*, release_*
|
||||
# - workflow_dispatch is always allowed
|
||||
#
|
||||
# All other repo/branch combinations are skipped.
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
outputs:
|
||||
should_run:
|
||||
description: Whether this repo owns the current branch
|
||||
value: ${{ jobs.check.outputs.should_run }}
|
||||
|
||||
jobs:
|
||||
check:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
should_run: ${{ steps.check.outputs.should_run }}
|
||||
steps:
|
||||
- name: Check branch ownership
|
||||
id: check
|
||||
run: |
|
||||
REPO="${{ github.repository }}"
|
||||
BRANCH="${{ github.ref_name }}"
|
||||
EVENT="${{ github.event_name }}"
|
||||
|
||||
if [[ "$EVENT" == "workflow_dispatch" ]]; then
|
||||
echo "should_run=true" >> $GITHUB_OUTPUT
|
||||
echo "Manual trigger — allowed"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# ansible/awx owns devel and feature_* branches
|
||||
if [[ "$REPO" == "ansible/awx" ]] && [[ "$BRANCH" == "devel" || "$BRANCH" == feature_* ]]; then
|
||||
echo "should_run=true" >> $GITHUB_OUTPUT
|
||||
echo "Repository '$REPO' owns branch '$BRANCH'"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# ansible/tower owns stable-* and release_* branches
|
||||
if [[ "$REPO" == "ansible/tower" ]] && [[ "$BRANCH" == stable-* || "$BRANCH" == release_* ]]; then
|
||||
echo "should_run=true" >> $GITHUB_OUTPUT
|
||||
echo "Repository '$REPO' owns branch '$BRANCH'"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "should_run=false" >> $GITHUB_OUTPUT
|
||||
echo "Repository '$REPO' does not own branch '$BRANCH' — skipping"
|
||||
48
.github/workflows/ci.yml
vendored
48
.github/workflows/ci.yml
vendored
@@ -4,46 +4,14 @@ env:
|
||||
LC_ALL: "C.UTF-8" # prevent ERROR: Ansible could not initialize the preferred locale: unsupported locale setting
|
||||
CI_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
DEV_DOCKER_OWNER: ${{ github.repository_owner }}
|
||||
COMPOSE_TAG: ${{ github.base_ref || github.ref_name || 'devel' }}
|
||||
COMPOSE_TAG: ${{ github.base_ref || 'devel' }}
|
||||
UPSTREAM_REPOSITORY_ID: 91594105
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- devel # needed to publish code coverage post-merge
|
||||
schedule:
|
||||
- cron: '0 12,18 * * 1-5'
|
||||
workflow_dispatch: {}
|
||||
jobs:
|
||||
trigger-release-branches:
|
||||
name: "Dispatch CI to release branches"
|
||||
if: github.event_name == 'schedule'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
actions: write
|
||||
steps:
|
||||
- name: Trigger CI on release_4.6
|
||||
id: dispatch_release_46
|
||||
continue-on-error: true
|
||||
run: gh workflow run ci.yml --ref release_4.6
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GH_REPO: ${{ github.repository }}
|
||||
- name: Trigger CI on stable-2.6
|
||||
id: dispatch_stable_26
|
||||
continue-on-error: true
|
||||
run: gh workflow run ci.yml --ref stable-2.6
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GH_REPO: ${{ github.repository }}
|
||||
- name: Check dispatch results
|
||||
if: steps.dispatch_release_46.outcome == 'failure' || steps.dispatch_stable_26.outcome == 'failure'
|
||||
run: |
|
||||
echo "One or more dispatches failed:"
|
||||
echo " release_4.6: ${{ steps.dispatch_release_46.outcome }}"
|
||||
echo " stable-2.6: ${{ steps.dispatch_stable_26.outcome }}"
|
||||
exit 1
|
||||
|
||||
common-tests:
|
||||
name: ${{ matrix.tests.name }}
|
||||
runs-on: ubuntu-latest
|
||||
@@ -94,11 +62,7 @@ jobs:
|
||||
run: |
|
||||
if [ -f "reports/coverage.xml" ]; then
|
||||
sed -i '2i<!-- PR ${{ github.event.pull_request.number }} -->' reports/coverage.xml
|
||||
echo "Injected PR number ${{ github.event.pull_request.number }} into reports/coverage.xml"
|
||||
fi
|
||||
if [ -f "awxkit/coverage.xml" ]; then
|
||||
sed -i '2i<!-- PR ${{ github.event.pull_request.number }} -->' awxkit/coverage.xml
|
||||
echo "Injected PR number ${{ github.event.pull_request.number }} into awxkit/coverage.xml"
|
||||
echo "Injected PR number ${{ github.event.pull_request.number }} into coverage.xml"
|
||||
fi
|
||||
|
||||
- name: Upload test coverage to Codecov
|
||||
@@ -145,9 +109,7 @@ jobs:
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ${{ matrix.tests.name }}-artifacts
|
||||
path: |
|
||||
reports/coverage.xml
|
||||
awxkit/coverage.xml
|
||||
path: reports/coverage.xml
|
||||
retention-days: 5
|
||||
|
||||
- name: >-
|
||||
@@ -160,7 +122,7 @@ jobs:
|
||||
&& github.event_name == 'push'
|
||||
&& env.UPSTREAM_REPOSITORY_ID == github.repository_id
|
||||
&& github.ref_name == github.event.repository.default_branch
|
||||
uses: ansible/gh-action-record-test-results@3784db66a1b7fb3809999a7251c8a7203a7ffbe8
|
||||
uses: ansible/gh-action-record-test-results@cd5956ead39ec66351d0779470c8cff9638dd2b8
|
||||
with:
|
||||
aggregation-server-url: ${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_URL }}
|
||||
http-auth-password: >-
|
||||
@@ -334,7 +296,7 @@ jobs:
|
||||
&& github.event_name == 'push'
|
||||
&& env.UPSTREAM_REPOSITORY_ID == github.repository_id
|
||||
&& github.ref_name == github.event.repository.default_branch
|
||||
uses: ansible/gh-action-record-test-results@3784db66a1b7fb3809999a7251c8a7203a7ffbe8
|
||||
uses: ansible/gh-action-record-test-results@cd5956ead39ec66351d0779470c8cff9638dd2b8
|
||||
with:
|
||||
aggregation-server-url: ${{ vars.PDE_ORG_RESULTS_AGGREGATOR_UPLOAD_URL }}
|
||||
http-auth-password: >-
|
||||
|
||||
11
.github/workflows/devel_images.yml
vendored
11
.github/workflows/devel_images.yml
vendored
@@ -12,12 +12,7 @@ on:
|
||||
- feature_*
|
||||
- stable-*
|
||||
jobs:
|
||||
check-ownership:
|
||||
uses: ./.github/workflows/_repo-owns-branch.yml
|
||||
|
||||
push-development-images:
|
||||
needs: check-ownership
|
||||
if: needs.check-ownership.outputs.should_run == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 120
|
||||
permissions:
|
||||
@@ -35,6 +30,12 @@ jobs:
|
||||
make-target: awx-kube-buildx
|
||||
steps:
|
||||
|
||||
- name: Skipping build of awx image for non-awx repository
|
||||
run: |
|
||||
echo "Skipping build of awx image for non-awx repository"
|
||||
exit 0
|
||||
if: matrix.build-targets.image-name == 'awx' && !endsWith(github.repository, '/awx')
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
show-progress: false
|
||||
|
||||
7
.github/workflows/spec-sync-on-merge.yml
vendored
7
.github/workflows/spec-sync-on-merge.yml
vendored
@@ -16,16 +16,9 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- devel
|
||||
- 'stable-2.[6-9]'
|
||||
- 'stable-2.[1-9][0-9]'
|
||||
workflow_dispatch: # Allow manual triggering for testing
|
||||
jobs:
|
||||
check-ownership:
|
||||
uses: ./.github/workflows/_repo-owns-branch.yml
|
||||
|
||||
sync-openapi-spec:
|
||||
needs: check-ownership
|
||||
if: needs.check-ownership.outputs.should_run == 'true'
|
||||
name: Sync OpenAPI spec to central repo
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
|
||||
5
.github/workflows/upload_schema.yml
vendored
5
.github/workflows/upload_schema.yml
vendored
@@ -13,12 +13,7 @@ on:
|
||||
- feature_**
|
||||
- stable-**
|
||||
jobs:
|
||||
check-ownership:
|
||||
uses: ./.github/workflows/_repo-owns-branch.yml
|
||||
|
||||
push:
|
||||
needs: check-ownership
|
||||
if: needs.check-ownership.outputs.should_run == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 60
|
||||
permissions:
|
||||
|
||||
22
Makefile
22
Makefile
@@ -106,12 +106,6 @@ else
|
||||
DOCKER_KUBE_CACHE_FLAG=$(DOCKER_CACHE)
|
||||
endif
|
||||
|
||||
# AWX TUI variables
|
||||
AWX_HOST ?= https://localhost:8043
|
||||
AWX_USER ?= admin
|
||||
AWX_PASSWORD ?= $$(awk -F"'" '/^admin_password:/{print $$2}' tools/docker-compose/_sources/secrets/admin_password.yml 2>/dev/null || echo "admin")
|
||||
AWX_VERIFY_SSL ?= false
|
||||
|
||||
.PHONY: awx-link clean clean-tmp clean-venv requirements requirements_dev \
|
||||
update_requirements upgrade_requirements update_requirements_dev \
|
||||
docker_update_requirements docker_upgrade_requirements docker_update_requirements_dev \
|
||||
@@ -577,20 +571,6 @@ docker-compose-runtest: awx/projects docker-compose-sources
|
||||
docker-compose-build-schema: awx/projects docker-compose-sources
|
||||
$(DOCKER_COMPOSE) -f tools/docker-compose/_sources/docker-compose.yml run --rm --service-ports --no-deps awx_1 make genschema
|
||||
|
||||
awx-tui:
|
||||
@if ! command -v awx-tui > /dev/null 2>&1; then \
|
||||
$(PYTHON) -m pip install awx-tui; \
|
||||
fi
|
||||
@if [ -f "$(HOME)/.config/awx-tui/config.yaml" ]; then \
|
||||
$(PYTHON) -m awx_tui.main; \
|
||||
else \
|
||||
AWX_HOST=$(AWX_HOST) \
|
||||
AWX_USER=$(AWX_USER) \
|
||||
AWX_PASSWORD=$(AWX_PASSWORD) \
|
||||
AWX_VERIFY_SSL=$(AWX_VERIFY_SSL) \
|
||||
$(PYTHON) -m awx_tui.main --host $(AWX_HOST); \
|
||||
fi
|
||||
|
||||
SCHEMA_DIFF_BASE_FOLDER ?= awx
|
||||
SCHEMA_DIFF_BASE_BRANCH ?= devel
|
||||
detect-schema-change: genschema
|
||||
@@ -601,7 +581,7 @@ detect-schema-change: genschema
|
||||
|
||||
validate-openapi-schema: genschema
|
||||
@echo "Validating OpenAPI schema from schema.json..."
|
||||
@python3 -c "from openapi_spec_validator import validate; import json; spec = json.load(open('schema.json')); validate(spec); print('✓ Schema is valid')"
|
||||
@python3 -c "from openapi_spec_validator import validate; import json; spec = json.load(open('schema.json')); validate(spec); print('✓ OpenAPI Schema is valid!')"
|
||||
|
||||
docker-compose-clean: awx/projects
|
||||
$(DOCKER_COMPOSE) -f tools/docker-compose/_sources/docker-compose.yml rm -sf
|
||||
|
||||
@@ -131,14 +131,8 @@ class LoggedLoginView(auth_views.LoginView):
|
||||
|
||||
|
||||
class LoggedLogoutView(auth_views.LogoutView):
|
||||
# Override http_method_names to allow GET requests (Django 5.2+ defaults to POST only)
|
||||
http_method_names = ["get", "post", "options"]
|
||||
success_url_allowed_hosts = set(settings.LOGOUT_ALLOWED_HOSTS.split(",")) if settings.LOGOUT_ALLOWED_HOSTS else set()
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
"""Handle GET requests for logout (for backward compatibility)."""
|
||||
return self.post(request, *args, **kwargs)
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
if is_proxied_request():
|
||||
# 1) We intentionally don't obey ?next= here, just always redirect to platform login
|
||||
@@ -272,10 +266,7 @@ class APIView(views.APIView):
|
||||
response = self.handle_exception(self.__init_request_error__)
|
||||
if response.status_code == 401:
|
||||
if response.data and 'detail' in response.data:
|
||||
if getattr(settings, 'RESOURCE_SERVER__URL', None):
|
||||
response.data['detail'] += _(' Direct access is not allowed, authenticate via the platform gateway.')
|
||||
else:
|
||||
response.data['detail'] += _(' To establish a login session, visit') + ' /api/login/.'
|
||||
response.data['detail'] += _(' To establish a login session, visit') + ' /api/login/.'
|
||||
logger.info(status_msg)
|
||||
else:
|
||||
logger.warning(status_msg)
|
||||
|
||||
@@ -122,6 +122,7 @@ from awx.main.scheduler.task_manager_models import TaskManagerModels
|
||||
from awx.main.redact import UriCleaner, REPLACE_STR
|
||||
from awx.main.signals import update_inventory_computed_fields
|
||||
|
||||
|
||||
from awx.main.validators import vars_validate_or_raise
|
||||
|
||||
from awx.api.versioning import reverse
|
||||
@@ -2078,17 +2079,9 @@ class BulkHostCreateSerializer(serializers.Serializer):
|
||||
if request and not request.user.is_superuser:
|
||||
if request.user not in inv.admin_role:
|
||||
raise serializers.ValidationError(_(f'Inventory with id {inv.id} not found or lack permissions to add hosts.'))
|
||||
|
||||
# Performance optimization (AAP-67978): Instead of loading ALL host names from
|
||||
# the inventory, only check if the specific new names already exist in the database.
|
||||
current_hostnames = set(inv.hosts.values_list('name', flat=True))
|
||||
new_names = [host['name'] for host in attrs['hosts']]
|
||||
|
||||
new_name_counts = Counter(new_names)
|
||||
duplicates_in_new = [name for name, count in new_name_counts.items() if count > 1]
|
||||
unique_new_names = list(new_name_counts.keys())
|
||||
existing_duplicates = list(Host.objects.filter(inventory=inv, name__in=unique_new_names).values_list('name', flat=True))
|
||||
duplicate_new_names = list(set(duplicates_in_new + existing_duplicates))
|
||||
|
||||
duplicate_new_names = [n for n in new_names if n in current_hostnames or new_names.count(n) > 1]
|
||||
if duplicate_new_names:
|
||||
raise serializers.ValidationError(_(f'Hostnames must be unique in an inventory. Duplicates found: {duplicate_new_names}'))
|
||||
|
||||
@@ -2939,19 +2932,6 @@ class CredentialTypeSerializer(BaseSerializer):
|
||||
field['label'] = _(field['label'])
|
||||
if 'help_text' in field:
|
||||
field['help_text'] = _(field['help_text'])
|
||||
|
||||
# Deep copy inputs to avoid modifying the original model data
|
||||
inputs = value.get('inputs')
|
||||
if not isinstance(inputs, dict):
|
||||
inputs = {}
|
||||
value['inputs'] = copy.deepcopy(inputs)
|
||||
fields = value['inputs'].get('fields', [])
|
||||
if not isinstance(fields, list):
|
||||
fields = []
|
||||
|
||||
# Normalize fields and filter out internal fields
|
||||
value['inputs']['fields'] = [f for f in fields if not f.get('internal')]
|
||||
|
||||
return value
|
||||
|
||||
def filter_field_metadata(self, fields, method):
|
||||
@@ -4142,28 +4122,9 @@ class LaunchConfigurationBaseSerializer(BaseSerializer):
|
||||
attrs['extra_data'][key] = db_extra_data[key]
|
||||
|
||||
# Build unsaved version of this config, use it to detect prompts errors
|
||||
# Capture keys before _build_mock_obj pops pseudo-fields from attrs
|
||||
incoming_attr_keys = set(attrs.keys())
|
||||
mock_obj = self._build_mock_obj(attrs)
|
||||
ask_mapping_keys = set(ujt.get_ask_mapping().keys())
|
||||
requested_prompt_fields = incoming_attr_keys & ask_mapping_keys
|
||||
if 'extra_data' in incoming_attr_keys:
|
||||
requested_prompt_fields.add('extra_vars')
|
||||
requested_prompt_fields.add('survey_passwords')
|
||||
|
||||
# prompts_dict() pulls persisted M2M state (labels, credentials,
|
||||
# instance_groups) via the instance pk. Only re-validate the full prompt
|
||||
# state when the caller is switching the underlying template; otherwise
|
||||
# restrict validation to the fields the request explicitly provided.
|
||||
if 'unified_job_template' in attrs:
|
||||
prompts_to_validate = mock_obj.prompts_dict()
|
||||
elif requested_prompt_fields:
|
||||
prompts_to_validate = {k: v for k, v in mock_obj.prompts_dict().items() if k in requested_prompt_fields}
|
||||
else:
|
||||
prompts_to_validate = None
|
||||
|
||||
if prompts_to_validate is not None:
|
||||
accepted, rejected, errors = ujt._accept_or_ignore_job_kwargs(_exclude_errors=self.exclude_errors, **prompts_to_validate)
|
||||
if set(list(ujt.get_ask_mapping().keys()) + ['extra_data']) & set(attrs.keys()):
|
||||
accepted, rejected, errors = ujt._accept_or_ignore_job_kwargs(_exclude_errors=self.exclude_errors, **mock_obj.prompts_dict())
|
||||
else:
|
||||
# Only perform validation of prompts if prompts fields are provided
|
||||
errors = {}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
---
|
||||
collections:
|
||||
- name: ansible.receptor
|
||||
version: 2.0.8
|
||||
version: 2.0.6
|
||||
|
||||
@@ -14,7 +14,6 @@ import sys
|
||||
import time
|
||||
from base64 import b64encode
|
||||
from collections import OrderedDict
|
||||
from jwt import decode as _jwt_decode
|
||||
|
||||
from urllib3.exceptions import ConnectTimeoutError
|
||||
|
||||
@@ -56,16 +55,10 @@ from wsgiref.util import FileWrapper
|
||||
from drf_spectacular.utils import extend_schema_view, extend_schema
|
||||
|
||||
# django-ansible-base
|
||||
from ansible_base.lib.utils.requests import get_remote_hosts
|
||||
from ansible_base.rbac.models import RoleEvaluation
|
||||
from ansible_base.lib.utils.schema import extend_schema_if_available
|
||||
from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope
|
||||
|
||||
# flags
|
||||
from flags.state import flag_enabled
|
||||
|
||||
# AWX
|
||||
from awx.main.utils.workload_identity import retrieve_workload_identity_jwt_with_claims
|
||||
from awx.main.tasks.system import send_notifications, update_inventory_computed_fields
|
||||
from awx.main.access import get_user_queryset
|
||||
from awx.api.generics import (
|
||||
@@ -104,6 +97,7 @@ from awx.main.utils import (
|
||||
from awx.main.utils.encryption import encrypt_value
|
||||
from awx.main.utils.filters import SmartFilter
|
||||
from awx.main.utils.plugins import compute_cloud_inventory_sources
|
||||
from awx.main.utils.proxy import get_first_remote_host_from_headers
|
||||
from awx.main.utils.common import memoize
|
||||
from awx.main.redact import UriCleaner
|
||||
from awx.api.permissions import (
|
||||
@@ -1601,175 +1595,7 @@ class CredentialCopy(CopyAPIView):
|
||||
resource_purpose = 'copy of a credential'
|
||||
|
||||
|
||||
class OIDCCredentialTestMixin:
|
||||
"""
|
||||
Mixin to add OIDC workload identity token support to credential test endpoints.
|
||||
|
||||
This mixin provides methods to handle OIDC-enabled external credentials that use
|
||||
workload identity tokens for authentication.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _get_workload_identity_token(job_template: models.JobTemplate, jwt_aud: str) -> str:
|
||||
"""Generate a workload identity token for a job template.
|
||||
|
||||
Args:
|
||||
job_template: The JobTemplate instance to generate claims for
|
||||
jwt_aud: The JWT audience claim value
|
||||
|
||||
Returns:
|
||||
str: The generated JWT token
|
||||
"""
|
||||
claims = {
|
||||
AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME: job_template.organization.name,
|
||||
AutomationControllerJobScope.CLAIM_ORGANIZATION_ID: job_template.organization.id,
|
||||
AutomationControllerJobScope.CLAIM_PROJECT_NAME: job_template.project.name,
|
||||
AutomationControllerJobScope.CLAIM_PROJECT_ID: job_template.project.id,
|
||||
AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_NAME: job_template.name,
|
||||
AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_ID: job_template.id,
|
||||
AutomationControllerJobScope.CLAIM_PLAYBOOK_NAME: job_template.playbook,
|
||||
}
|
||||
return retrieve_workload_identity_jwt_with_claims(
|
||||
claims=claims,
|
||||
audience=jwt_aud,
|
||||
scope=AutomationControllerJobScope.name,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _decode_jwt_payload_for_display(jwt_token):
|
||||
"""Decode JWT payload for display purposes only (signature not verified).
|
||||
|
||||
This is safe because the JWT was just created by AWX and is only decoded
|
||||
to show the user what claims are being sent to the external system.
|
||||
The external system will perform proper signature verification.
|
||||
|
||||
Args:
|
||||
jwt_token: The JWT token to decode
|
||||
|
||||
Returns:
|
||||
dict: The decoded JWT payload
|
||||
"""
|
||||
return _jwt_decode(jwt_token, algorithms=["RS256"], options={"verify_signature": False}) # NOSONAR python:S5659
|
||||
|
||||
def _has_workload_identity_token(self, credential_type_inputs):
|
||||
"""Check if credential type has an internal workload_identity_token field.
|
||||
|
||||
Args:
|
||||
credential_type_inputs: The inputs dict from a credential type
|
||||
|
||||
Returns:
|
||||
bool: True if the credential type has a workload_identity_token field marked as internal
|
||||
"""
|
||||
fields = credential_type_inputs.get('fields', []) if isinstance(credential_type_inputs, dict) else []
|
||||
return any(field.get('internal') and field.get('id') == 'workload_identity_token' for field in fields)
|
||||
|
||||
def _validate_and_get_job_template(self, job_template_id):
|
||||
"""Validate job template ID and return the JobTemplate instance.
|
||||
|
||||
Args:
|
||||
job_template_id: The job template ID from metadata
|
||||
|
||||
Returns:
|
||||
JobTemplate instance
|
||||
|
||||
Raises:
|
||||
ParseError: If job_template_id is invalid or not found
|
||||
"""
|
||||
if job_template_id is None:
|
||||
raise ParseError(_('Job template ID is required.'))
|
||||
|
||||
try:
|
||||
return models.JobTemplate.objects.get(id=int(job_template_id))
|
||||
except ValueError:
|
||||
raise ParseError(_('Job template ID must be an integer.'))
|
||||
except models.JobTemplate.DoesNotExist:
|
||||
raise ParseError(_('Job template with ID %(id)s does not exist.') % {'id': job_template_id})
|
||||
|
||||
def _handle_oidc_credential_test(self, backend_kwargs):
|
||||
"""
|
||||
Handle OIDC workload identity token generation for external credential test endpoints.
|
||||
|
||||
This method should only be called when FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED is enabled
|
||||
and the credential type has a workload_identity_token field.
|
||||
|
||||
Args:
|
||||
backend_kwargs: The kwargs dict to pass to the backend (will be modified in place)
|
||||
|
||||
Returns:
|
||||
dict: Response body containing details with the sent JWT payload
|
||||
|
||||
Raises:
|
||||
PermissionDenied: If user lacks access to the job template (re-raised for 403 response)
|
||||
|
||||
All other exceptions are caught and converted to 400 responses with error details.
|
||||
|
||||
Modifies backend_kwargs in place to add workload_identity_token.
|
||||
"""
|
||||
# Validate job template
|
||||
job_template_id = backend_kwargs.pop('job_template_id', None)
|
||||
job_template = self._validate_and_get_job_template(job_template_id)
|
||||
|
||||
# Check user access
|
||||
if not self.request.user.can_access(models.JobTemplate, 'start', job_template):
|
||||
raise PermissionDenied(_('You do not have access to job template with id: %(id)s.') % {'id': job_template.id})
|
||||
|
||||
# Generate workload identity token
|
||||
jwt_token = self._get_workload_identity_token(job_template, backend_kwargs.pop('jwt_aud', None))
|
||||
backend_kwargs['workload_identity_token'] = jwt_token
|
||||
|
||||
return {'details': {'sent_jwt_payload': self._decode_jwt_payload_for_display(jwt_token)}}
|
||||
|
||||
def _call_backend_with_error_handling(self, plugin, backend_kwargs, response_body):
|
||||
"""Call credential backend and handle errors."""
|
||||
try:
|
||||
with set_environ(**settings.AWX_TASK_ENV):
|
||||
plugin.backend(**backend_kwargs)
|
||||
return Response(response_body, status=status.HTTP_202_ACCEPTED)
|
||||
except requests.exceptions.HTTPError as exc:
|
||||
message = self._extract_http_error_message(exc)
|
||||
self._add_error_to_response(response_body, message)
|
||||
return Response(response_body, status=status.HTTP_400_BAD_REQUEST)
|
||||
except Exception as exc:
|
||||
message = self._extract_generic_error_message(exc)
|
||||
self._add_error_to_response(response_body, message)
|
||||
return Response(response_body, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
@staticmethod
|
||||
def _extract_http_error_message(exc):
|
||||
"""Extract error message from HTTPError, checking response JSON and text."""
|
||||
message = str(exc)
|
||||
if not hasattr(exc, 'response') or exc.response is None:
|
||||
return message
|
||||
|
||||
try:
|
||||
error_data = exc.response.json()
|
||||
if 'errors' in error_data and error_data['errors']:
|
||||
return ', '.join(error_data['errors'])
|
||||
if 'error' in error_data:
|
||||
return error_data['error']
|
||||
except (ValueError, KeyError):
|
||||
if exc.response.text:
|
||||
return exc.response.text
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def _extract_generic_error_message(exc):
|
||||
"""Extract error message from exception, handling ConnectTimeoutError specially."""
|
||||
message = str(exc) if str(exc) else exc.__class__.__name__
|
||||
for arg in getattr(exc, 'args', []):
|
||||
if isinstance(getattr(arg, 'reason', None), ConnectTimeoutError):
|
||||
return str(arg.reason)
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def _add_error_to_response(response_body, message):
|
||||
"""Add error message to both 'detail' and 'details.error_message' fields."""
|
||||
response_body['detail'] = message
|
||||
if 'details' in response_body:
|
||||
response_body['details']['error_message'] = message
|
||||
|
||||
|
||||
class CredentialExternalTest(OIDCCredentialTestMixin, SubDetailAPIView):
|
||||
class CredentialExternalTest(SubDetailAPIView):
|
||||
"""
|
||||
Test updates to the input values and metadata of an external credential
|
||||
before saving them.
|
||||
@@ -1789,8 +1615,6 @@ class CredentialExternalTest(OIDCCredentialTestMixin, SubDetailAPIView):
|
||||
It does not support standard credential types such as Machine, SCM, and Cloud."""})
|
||||
def post(self, request, *args, **kwargs):
|
||||
obj = self.get_object()
|
||||
if obj.credential_type.kind != 'external':
|
||||
raise ParseError(_('Credential is not testable.'))
|
||||
backend_kwargs = {}
|
||||
for field_name, value in obj.inputs.items():
|
||||
backend_kwargs[field_name] = obj.get_input(field_name)
|
||||
@@ -1798,22 +1622,23 @@ class CredentialExternalTest(OIDCCredentialTestMixin, SubDetailAPIView):
|
||||
if value != '$encrypted$':
|
||||
backend_kwargs[field_name] = value
|
||||
backend_kwargs.update(request.data.get('metadata', {}))
|
||||
|
||||
# Handle OIDC workload identity token generation if enabled
|
||||
response_body = {}
|
||||
if flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED') and self._has_workload_identity_token(obj.credential_type.inputs):
|
||||
try:
|
||||
oidc_response_body = self._handle_oidc_credential_test(backend_kwargs)
|
||||
response_body.update(oidc_response_body)
|
||||
except PermissionDenied:
|
||||
raise
|
||||
except Exception as exc:
|
||||
error_message = str(exc.detail) if hasattr(exc, 'detail') else str(exc)
|
||||
response_body['detail'] = error_message
|
||||
response_body['details'] = {'error_message': error_message}
|
||||
return Response(response_body, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
return self._call_backend_with_error_handling(obj.credential_type.plugin, backend_kwargs, response_body)
|
||||
try:
|
||||
with set_environ(**settings.AWX_TASK_ENV):
|
||||
obj.credential_type.plugin.backend(**backend_kwargs)
|
||||
return Response({}, status=status.HTTP_202_ACCEPTED)
|
||||
except requests.exceptions.HTTPError:
|
||||
message = """Test operation is not supported for credential type {}.
|
||||
This endpoint only supports credentials that connect to
|
||||
external secret management systems such as CyberArk, HashiCorp
|
||||
Vault, or cloud-based secret managers.""".format(obj.credential_type.kind)
|
||||
return Response({'detail': message}, status=status.HTTP_400_BAD_REQUEST)
|
||||
except Exception as exc:
|
||||
message = exc.__class__.__name__
|
||||
exc_args = getattr(exc, 'args', [])
|
||||
for a in exc_args:
|
||||
if isinstance(getattr(a, 'reason', None), ConnectTimeoutError):
|
||||
message = str(a.reason)
|
||||
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
|
||||
class CredentialInputSourceDetail(RetrieveUpdateDestroyAPIView):
|
||||
@@ -1843,7 +1668,7 @@ class CredentialInputSourceSubList(SubListCreateAPIView):
|
||||
parent_key = 'target_credential'
|
||||
|
||||
|
||||
class CredentialTypeExternalTest(OIDCCredentialTestMixin, SubDetailAPIView):
|
||||
class CredentialTypeExternalTest(SubDetailAPIView):
|
||||
"""
|
||||
Test a complete set of input values for an external credential before
|
||||
saving it.
|
||||
@@ -1858,26 +1683,21 @@ class CredentialTypeExternalTest(OIDCCredentialTestMixin, SubDetailAPIView):
|
||||
@extend_schema_if_available(extensions={"x-ai-description": "Test a complete set of input values for an external credential"})
|
||||
def post(self, request, *args, **kwargs):
|
||||
obj = self.get_object()
|
||||
if obj.kind != 'external':
|
||||
raise ParseError(_('Credential type is not testable.'))
|
||||
backend_kwargs = request.data.get('inputs', {})
|
||||
backend_kwargs.update(request.data.get('metadata', {}))
|
||||
|
||||
# Handle OIDC workload identity token generation if enabled
|
||||
response_body = {}
|
||||
if flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED') and self._has_workload_identity_token(obj.inputs):
|
||||
try:
|
||||
oidc_response_body = self._handle_oidc_credential_test(backend_kwargs)
|
||||
response_body.update(oidc_response_body)
|
||||
except PermissionDenied:
|
||||
raise
|
||||
except Exception as exc:
|
||||
error_message = str(exc.detail) if hasattr(exc, 'detail') else str(exc)
|
||||
response_body['detail'] = error_message
|
||||
response_body['details'] = {'error_message': error_message}
|
||||
return Response(response_body, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
return self._call_backend_with_error_handling(obj.plugin, backend_kwargs, response_body)
|
||||
try:
|
||||
obj.plugin.backend(**backend_kwargs)
|
||||
return Response({}, status=status.HTTP_202_ACCEPTED)
|
||||
except requests.exceptions.HTTPError as exc:
|
||||
message = 'HTTP {}'.format(exc.response.status_code)
|
||||
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
|
||||
except Exception as exc:
|
||||
message = exc.__class__.__name__
|
||||
args_exc = getattr(exc, 'args', [])
|
||||
for a in args_exc:
|
||||
if isinstance(getattr(a, 'reason', None), ConnectTimeoutError):
|
||||
message = str(a.reason)
|
||||
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
|
||||
class HostRelatedSearchMixin(object):
|
||||
@@ -3057,7 +2877,8 @@ class JobTemplateCallback(GenericAPIView):
|
||||
host for the current request.
|
||||
"""
|
||||
# Find the list of remote host names/IPs to check.
|
||||
remote_hosts = set(get_remote_hosts(self.request))
|
||||
# Only consider the first entry from each header (for comma-separated values like X-Forwarded-For)
|
||||
remote_hosts = get_first_remote_host_from_headers(self.request, settings.REMOTE_HOST_HEADERS)
|
||||
# Add the reverse lookup of IP addresses.
|
||||
for rh in list(remote_hosts):
|
||||
try:
|
||||
|
||||
@@ -344,22 +344,13 @@ class ApiV2ConfigView(APIView):
|
||||
become_methods=PRIVILEGE_ESCALATION_METHODS,
|
||||
)
|
||||
|
||||
# Check superuser/auditor first
|
||||
if request.user.is_superuser or request.user.is_system_auditor:
|
||||
has_org_access = True
|
||||
else:
|
||||
# Single query checking all three organization role types at once
|
||||
has_org_access = (
|
||||
(
|
||||
Organization.access_qs(request.user, 'change')
|
||||
| Organization.access_qs(request.user, 'audit')
|
||||
| Organization.access_qs(request.user, 'add_project')
|
||||
)
|
||||
.distinct()
|
||||
.exists()
|
||||
)
|
||||
|
||||
if has_org_access:
|
||||
if (
|
||||
request.user.is_superuser
|
||||
or request.user.is_system_auditor
|
||||
or Organization.accessible_objects(request.user, 'admin_role').exists()
|
||||
or Organization.accessible_objects(request.user, 'auditor_role').exists()
|
||||
or Organization.accessible_objects(request.user, 'project_admin_role').exists()
|
||||
):
|
||||
data.update(
|
||||
dict(
|
||||
project_base_dir=settings.PROJECTS_ROOT,
|
||||
@@ -367,10 +358,8 @@ class ApiV2ConfigView(APIView):
|
||||
custom_virtualenvs=get_custom_venv_choices(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Only check JobTemplate access if org check failed
|
||||
if JobTemplate.accessible_objects(request.user, 'admin_role').exists():
|
||||
data['custom_virtualenvs'] = get_custom_venv_choices()
|
||||
elif JobTemplate.accessible_objects(request.user, 'admin_role').exists():
|
||||
data['custom_virtualenvs'] = get_custom_venv_choices()
|
||||
|
||||
return Response(data)
|
||||
|
||||
|
||||
@@ -902,10 +902,6 @@ class HostAccess(BaseAccess):
|
||||
)
|
||||
prefetch_related = ('groups', 'inventory_sources')
|
||||
|
||||
def get_queryset(self):
|
||||
qs = super().get_queryset()
|
||||
return qs.exclude(inventory__kind='constructed')
|
||||
|
||||
def filtered_queryset(self):
|
||||
return self.model.objects.filter(inventory__in=Inventory.accessible_pk_qs(self.user, 'read_role'))
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ __all__ = [
|
||||
'CAN_CANCEL',
|
||||
'ACTIVE_STATES',
|
||||
'STANDARD_INVENTORY_UPDATE_ENV',
|
||||
'OIDC_CREDENTIAL_TYPE_NAMESPACES',
|
||||
]
|
||||
|
||||
PRIVILEGE_ESCALATION_METHODS = [
|
||||
@@ -141,6 +140,3 @@ org_role_to_permission = {
|
||||
'execution_environment_admin_role': 'add_executionenvironment',
|
||||
'auditor_role': 'view_project', # TODO: also doesnt really work
|
||||
}
|
||||
|
||||
# OIDC credential type namespaces for feature flag filtering
|
||||
OIDC_CREDENTIAL_TYPE_NAMESPACES = ['hashivault-kv-oidc', 'hashivault-ssh-oidc']
|
||||
|
||||
@@ -27,10 +27,6 @@ def get_dispatcherd_config(for_service: bool = False, mock_publish: bool = False
|
||||
"pool_kwargs": {
|
||||
"min_workers": settings.JOB_EVENT_WORKERS,
|
||||
"max_workers": max_workers,
|
||||
# This must be less than max_workers to make sense, which is usually 4
|
||||
# With reserve of 1, after a burst of tasks, load needs to down to 4-1=3
|
||||
# before we return to min_workers
|
||||
"scaledown_reserve": 1,
|
||||
},
|
||||
"main_kwargs": {"node_id": settings.CLUSTER_HOST_ID},
|
||||
"process_manager_cls": "ForkServerManager",
|
||||
|
||||
@@ -77,13 +77,13 @@ class CallbackBrokerWorker:
|
||||
|
||||
MAX_RETRIES = 2
|
||||
INDIVIDUAL_EVENT_RETRIES = 3
|
||||
last_stats = time.time()
|
||||
last_flush = time.time()
|
||||
total = 0
|
||||
last_event = ''
|
||||
prof = None
|
||||
|
||||
def __init__(self):
|
||||
self.last_stats = time.time()
|
||||
self.last_flush = time.time()
|
||||
self.buff = {}
|
||||
self.redis = get_redis_client()
|
||||
self.subsystem_metrics = s_metrics.CallbackReceiverMetrics(auto_pipe_execute=False)
|
||||
|
||||
@@ -428,9 +428,6 @@ class CredentialInputField(JSONSchemaField):
|
||||
# determine the defined fields for the associated credential type
|
||||
properties = {}
|
||||
for field in model_instance.credential_type.inputs.get('fields', []):
|
||||
# Prevent users from providing values for internally resolved fields
|
||||
if 'internal' in field:
|
||||
continue
|
||||
field = field.copy()
|
||||
properties[field['id']] = field
|
||||
if field.get('choices', []):
|
||||
@@ -569,7 +566,6 @@ class CredentialTypeInputField(JSONSchemaField):
|
||||
},
|
||||
'label': {'type': 'string'},
|
||||
'help_text': {'type': 'string'},
|
||||
'internal': {'type': 'boolean'},
|
||||
'multiline': {'type': 'boolean'},
|
||||
'secret': {'type': 'boolean'},
|
||||
'ask_at_runtime': {'type': 'boolean'},
|
||||
|
||||
@@ -409,12 +409,10 @@ class Command(BaseCommand):
|
||||
del_child_group_pks = list(set(db_children_name_pk_map.values()))
|
||||
for offset in range(0, len(del_child_group_pks), self._batch_size):
|
||||
child_group_pks = del_child_group_pks[offset : (offset + self._batch_size)]
|
||||
children_to_remove = list(db_children.filter(pk__in=child_group_pks))
|
||||
if children_to_remove:
|
||||
group_group_count += len(children_to_remove)
|
||||
db_group.children.remove(*children_to_remove)
|
||||
for db_child in children_to_remove:
|
||||
logger.debug('Group "%s" removed from group "%s"', db_child.name, db_group.name)
|
||||
for db_child in db_children.filter(pk__in=child_group_pks):
|
||||
group_group_count += 1
|
||||
db_group.children.remove(db_child)
|
||||
logger.debug('Group "%s" removed from group "%s"', db_child.name, db_group.name)
|
||||
# FIXME: Inventory source group relationships
|
||||
# Delete group/host relationships not present in imported data.
|
||||
db_hosts = db_group.hosts
|
||||
@@ -443,12 +441,12 @@ class Command(BaseCommand):
|
||||
del_host_pks = list(del_host_pks)
|
||||
for offset in range(0, len(del_host_pks), self._batch_size):
|
||||
del_pks = del_host_pks[offset : (offset + self._batch_size)]
|
||||
hosts_to_remove = list(db_hosts.filter(pk__in=del_pks))
|
||||
if hosts_to_remove:
|
||||
group_host_count += len(hosts_to_remove)
|
||||
db_group.hosts.remove(*hosts_to_remove)
|
||||
for db_host in hosts_to_remove:
|
||||
logger.debug('Host "%s" removed from group "%s"', db_host.name, db_group.name)
|
||||
for db_host in db_hosts.filter(pk__in=del_pks):
|
||||
group_host_count += 1
|
||||
if db_host not in db_group.hosts.all():
|
||||
continue
|
||||
db_group.hosts.remove(db_host)
|
||||
logger.debug('Host "%s" removed from group "%s"', db_host.name, db_group.name)
|
||||
if settings.SQL_DEBUG:
|
||||
logger.warning(
|
||||
'group-group and group-host deletions took %d queries for %d relationships',
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
# Generated by Django 5.2.8 on 2026-02-20 03:39
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('main', '0204_squashed_deletions'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name='instancegroup',
|
||||
options={
|
||||
'default_permissions': ('change', 'delete', 'view'),
|
||||
'ordering': ('pk',),
|
||||
'permissions': [('use_instancegroup', 'Can use instance group in a preference list of a resource')],
|
||||
},
|
||||
),
|
||||
migrations.AlterModelOptions(
|
||||
name='workflowjobnode',
|
||||
options={'ordering': ('pk',)},
|
||||
),
|
||||
migrations.AlterModelOptions(
|
||||
name='workflowjobtemplatenode',
|
||||
options={'ordering': ('pk',)},
|
||||
),
|
||||
]
|
||||
@@ -28,7 +28,6 @@ from rest_framework.serializers import ValidationError as DRFValidationError
|
||||
from ansible_base.lib.utils.db import advisory_lock
|
||||
|
||||
# AWX
|
||||
from awx.main.constants import OIDC_CREDENTIAL_TYPE_NAMESPACES
|
||||
from awx.api.versioning import reverse
|
||||
from awx.main.fields import (
|
||||
ImplicitRoleField,
|
||||
@@ -243,29 +242,6 @@ class Credential(PasswordFieldsModel, CommonModelNameNotUnique, ResourceMixin):
|
||||
needed.append('vault_password')
|
||||
return needed
|
||||
|
||||
@functools.cached_property
|
||||
def context(self):
|
||||
"""
|
||||
Property for storing runtime context during credential resolution.
|
||||
|
||||
The context is a dict keyed by CredentialInputSource PK, where each value
|
||||
is a dict of runtime fields for that input source. Example::
|
||||
|
||||
{
|
||||
<input_source_pk>: {
|
||||
"workload_identity_token": "<jwt_token>"
|
||||
},
|
||||
<another_input_source_pk>: {
|
||||
"workload_identity_token": "<different_jwt_token>"
|
||||
},
|
||||
}
|
||||
|
||||
This structure allows each input source to have its own set of runtime
|
||||
values, avoiding conflicts when a credential has multiple input sources
|
||||
with different configurations (e.g., different JWT audiences).
|
||||
"""
|
||||
return {}
|
||||
|
||||
@cached_property
|
||||
def dynamic_input_fields(self):
|
||||
# if the credential is not yet saved we can't access the input_sources
|
||||
@@ -391,7 +367,7 @@ class Credential(PasswordFieldsModel, CommonModelNameNotUnique, ResourceMixin):
|
||||
def _get_dynamic_input(self, field_name):
|
||||
for input_source in self.input_sources.all():
|
||||
if input_source.input_field_name == field_name:
|
||||
return input_source.get_input_value(context=self.context)
|
||||
return input_source.get_input_value()
|
||||
else:
|
||||
raise ValueError('{} is not a dynamic input field'.format(field_name))
|
||||
|
||||
@@ -459,15 +435,13 @@ class CredentialType(CommonModelNameNotUnique):
|
||||
def from_db(cls, db, field_names, values):
|
||||
instance = super(CredentialType, cls).from_db(db, field_names, values)
|
||||
if instance.managed and instance.namespace and instance.kind != "external":
|
||||
native = ManagedCredentialType.registry.get(instance.namespace)
|
||||
if native:
|
||||
instance.inputs = native.inputs
|
||||
instance.injectors = native.injectors
|
||||
instance.custom_injectors = getattr(native, 'custom_injectors', None)
|
||||
native = ManagedCredentialType.registry[instance.namespace]
|
||||
instance.inputs = native.inputs
|
||||
instance.injectors = native.injectors
|
||||
instance.custom_injectors = getattr(native, 'custom_injectors', None)
|
||||
elif instance.namespace and instance.kind == "external":
|
||||
native = ManagedCredentialType.registry.get(instance.namespace)
|
||||
if native:
|
||||
instance.inputs = native.inputs
|
||||
native = ManagedCredentialType.registry[instance.namespace]
|
||||
instance.inputs = native.inputs
|
||||
|
||||
return instance
|
||||
|
||||
@@ -531,7 +505,6 @@ class CredentialType(CommonModelNameNotUnique):
|
||||
existing = ct_class.objects.filter(name=default.name, kind=default.kind).first()
|
||||
if existing is not None:
|
||||
existing.namespace = default.namespace
|
||||
existing.description = getattr(default, 'description', '')
|
||||
existing.inputs = {}
|
||||
existing.injectors = {}
|
||||
existing.save()
|
||||
@@ -571,14 +544,7 @@ class CredentialType(CommonModelNameNotUnique):
|
||||
@classmethod
|
||||
def load_plugin(cls, ns, plugin):
|
||||
# TODO: User "side-loaded" credential custom_injectors isn't supported
|
||||
ManagedCredentialType.registry[ns] = SimpleNamespace(
|
||||
namespace=ns,
|
||||
name=plugin.name,
|
||||
kind='external',
|
||||
inputs=plugin.inputs,
|
||||
backend=plugin.backend,
|
||||
description=getattr(plugin, 'plugin_description', ''),
|
||||
)
|
||||
ManagedCredentialType.registry[ns] = SimpleNamespace(namespace=ns, name=plugin.name, kind='external', inputs=plugin.inputs, backend=plugin.backend)
|
||||
|
||||
def inject_credential(self, credential, env, safe_env, args, private_data_dir, container_root=None):
|
||||
from awx_plugins.interfaces._temporary_private_inject_api import inject_credential
|
||||
@@ -590,13 +556,7 @@ class CredentialTypeHelper:
|
||||
@classmethod
|
||||
def get_creation_params(cls, cred_type):
|
||||
if cred_type.kind == 'external':
|
||||
return {
|
||||
'namespace': cred_type.namespace,
|
||||
'kind': cred_type.kind,
|
||||
'name': cred_type.name,
|
||||
'managed': True,
|
||||
'description': getattr(cred_type, 'description', ''),
|
||||
}
|
||||
return dict(namespace=cred_type.namespace, kind=cred_type.kind, name=cred_type.name, managed=True)
|
||||
return dict(
|
||||
namespace=cred_type.namespace,
|
||||
kind=cred_type.kind,
|
||||
@@ -662,15 +622,7 @@ class CredentialInputSource(PrimordialModel):
|
||||
raise ValidationError(_('Input field must be defined on target credential (options are {}).'.format(', '.join(sorted(defined_fields)))))
|
||||
return self.input_field_name
|
||||
|
||||
def get_input_value(self, context: dict | None = None):
|
||||
"""
|
||||
Retrieve the value from the external credential backend.
|
||||
|
||||
Args:
|
||||
context: Optional runtime context dict passed from the target credential.
|
||||
"""
|
||||
if context is None:
|
||||
context = {}
|
||||
def get_input_value(self):
|
||||
backend = self.source_credential.credential_type.plugin.backend
|
||||
backend_kwargs = {}
|
||||
for field_name, value in self.source_credential.inputs.items():
|
||||
@@ -681,17 +633,6 @@ class CredentialInputSource(PrimordialModel):
|
||||
|
||||
backend_kwargs.update(self.metadata)
|
||||
|
||||
# Resolve internal fields from the per-input-source context.
|
||||
# The context dict is keyed by input source PK, e.g.:
|
||||
# {42: {"workload_identity_token": "eyJ..."}, 43: {"workload_identity_token": "eyX..."}}
|
||||
# This allows each input source to carry its own runtime values.
|
||||
input_source_context = context.get(self.pk, {})
|
||||
for field in self.source_credential.credential_type.inputs.get('fields', []):
|
||||
if field.get('internal'):
|
||||
value = input_source_context.get(field['id'])
|
||||
if value is not None:
|
||||
backend_kwargs[field['id']] = value
|
||||
|
||||
with set_environ(**settings.AWX_TASK_ENV):
|
||||
return backend(**backend_kwargs)
|
||||
|
||||
@@ -700,20 +641,13 @@ class CredentialInputSource(PrimordialModel):
|
||||
return reverse(view_name, kwargs={'pk': self.pk}, request=request)
|
||||
|
||||
|
||||
def _is_oidc_namespace_disabled(ns):
|
||||
"""Check if a credential namespace should be skipped based on the OIDC feature flag."""
|
||||
return ns in OIDC_CREDENTIAL_TYPE_NAMESPACES and not getattr(settings, 'FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED', False)
|
||||
|
||||
|
||||
def load_credentials():
|
||||
|
||||
awx_entry_points = {ep.name: ep for ep in entry_points(group='awx_plugins.managed_credentials')}
|
||||
supported_entry_points = {ep.name: ep for ep in entry_points(group='awx_plugins.managed_credentials.supported')}
|
||||
plugin_entry_points = awx_entry_points if detect_server_product_name() == 'AWX' else {**awx_entry_points, **supported_entry_points}
|
||||
|
||||
for ns, ep in plugin_entry_points.items():
|
||||
if _is_oidc_namespace_disabled(ns):
|
||||
continue
|
||||
|
||||
cred_plugin = ep.load()
|
||||
if not hasattr(cred_plugin, 'inputs'):
|
||||
setattr(cred_plugin, 'inputs', {})
|
||||
@@ -732,8 +666,5 @@ def load_credentials():
|
||||
credential_plugins = {}
|
||||
|
||||
for ns, ep in credential_plugins.items():
|
||||
if _is_oidc_namespace_disabled(ns):
|
||||
continue
|
||||
|
||||
plugin = ep.load()
|
||||
CredentialType.load_plugin(ns, plugin)
|
||||
|
||||
@@ -485,7 +485,6 @@ class InstanceGroup(HasPolicyEditsMixin, BaseModel, RelatedJobsMixin, ResourceMi
|
||||
|
||||
class Meta:
|
||||
app_label = 'main'
|
||||
ordering = ('pk',)
|
||||
permissions = [('use_instancegroup', 'Can use instance group in a preference list of a resource')]
|
||||
# Since this has no direct organization field only superuser can add, so remove add permission
|
||||
default_permissions = ('change', 'delete', 'view')
|
||||
|
||||
@@ -845,21 +845,6 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana
|
||||
def get_notification_friendly_name(self):
|
||||
return "Job"
|
||||
|
||||
def get_source_hosts_for_constructed_inventory(self):
|
||||
"""Return a QuerySet of the source (input inventory) hosts for a constructed inventory.
|
||||
|
||||
Constructed inventory hosts have an instance_id pointing to the real
|
||||
host in the input inventory. This resolves those references and returns
|
||||
a proper QuerySet (never a list), suitable for use with finish_fact_cache.
|
||||
"""
|
||||
Host = JobHostSummary._meta.get_field('host').related_model
|
||||
if not self.inventory_id:
|
||||
return Host.objects.none()
|
||||
id_field = Host._meta.get_field('id')
|
||||
return Host.objects.filter(id__in=self.inventory.hosts.exclude(instance_id='').values_list(Cast('instance_id', output_field=id_field))).only(
|
||||
*HOST_FACTS_FIELDS
|
||||
)
|
||||
|
||||
def get_hosts_for_fact_cache(self):
|
||||
"""
|
||||
Builds the queryset to use for writing or finalizing the fact cache
|
||||
@@ -867,15 +852,17 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana
|
||||
For constructed inventories, that means the original (input inventory) hosts
|
||||
when slicing, that means only returning hosts in that slice
|
||||
"""
|
||||
Host = JobHostSummary._meta.get_field('host').related_model
|
||||
if not self.inventory_id:
|
||||
Host = JobHostSummary._meta.get_field('host').related_model
|
||||
return Host.objects.none()
|
||||
|
||||
if self.inventory.kind == 'constructed':
|
||||
host_qs = self.get_source_hosts_for_constructed_inventory()
|
||||
id_field = Host._meta.get_field('id')
|
||||
host_qs = Host.objects.filter(id__in=self.inventory.hosts.exclude(instance_id='').values_list(Cast('instance_id', output_field=id_field)))
|
||||
else:
|
||||
host_qs = self.inventory.hosts.only(*HOST_FACTS_FIELDS)
|
||||
host_qs = self.inventory.hosts
|
||||
|
||||
host_qs = host_qs.only(*HOST_FACTS_FIELDS)
|
||||
host_qs = self.inventory.get_sliced_hosts(host_qs, self.job_slice_number, self.job_slice_count)
|
||||
return host_qs
|
||||
|
||||
|
||||
@@ -72,10 +72,10 @@ def _fast_forward_rrule(rrule, ref_dt=None):
|
||||
if ref_dt is None:
|
||||
ref_dt = now()
|
||||
|
||||
dtstart_tz = rrule._dtstart.tzinfo
|
||||
ref_dt = ref_dt.astimezone(dtstart_tz)
|
||||
ref_dt = ref_dt.astimezone(datetime.timezone.utc)
|
||||
|
||||
if rrule._dtstart > ref_dt:
|
||||
rrule_dtstart_utc = rrule._dtstart.astimezone(datetime.timezone.utc)
|
||||
if rrule_dtstart_utc > ref_dt:
|
||||
return rrule
|
||||
|
||||
interval = rrule._interval if rrule._interval else 1
|
||||
@@ -84,14 +84,20 @@ def _fast_forward_rrule(rrule, ref_dt=None):
|
||||
elif rrule._freq == dateutil.rrule.MINUTELY:
|
||||
interval *= 60
|
||||
|
||||
# if after converting to seconds the interval is still a fraction,
|
||||
# just return original rrule
|
||||
if isinstance(interval, float) and not interval.is_integer():
|
||||
return rrule
|
||||
|
||||
seconds_since_dtstart = (ref_dt - rrule._dtstart).total_seconds()
|
||||
seconds_since_dtstart = (ref_dt - rrule_dtstart_utc).total_seconds()
|
||||
|
||||
# it is important to fast forward by a number that is divisible by
|
||||
# interval. For example, if interval is 7 hours, we fast forward by 7, 14, 21, etc. hours.
|
||||
# Otherwise, the occurrences after the fast forward might not match the ones before.
|
||||
# x // y is integer division, lopping off any remainder, so that we get the outcome we want.
|
||||
interval_aligned_offset = datetime.timedelta(seconds=(seconds_since_dtstart // interval) * interval)
|
||||
new_start = rrule._dtstart + interval_aligned_offset
|
||||
new_rrule = rrule.replace(dtstart=new_start)
|
||||
new_start = rrule_dtstart_utc + interval_aligned_offset
|
||||
new_rrule = rrule.replace(dtstart=new_start.astimezone(rrule._dtstart.tzinfo))
|
||||
return new_rrule
|
||||
|
||||
|
||||
|
||||
@@ -200,7 +200,6 @@ class WorkflowJobTemplateNode(WorkflowNodeBase):
|
||||
indexes = [
|
||||
models.Index(fields=['identifier']),
|
||||
]
|
||||
ordering = ('pk',)
|
||||
|
||||
def get_absolute_url(self, request=None):
|
||||
return reverse('api:workflow_job_template_node_detail', kwargs={'pk': self.pk}, request=request)
|
||||
@@ -287,7 +286,6 @@ class WorkflowJobNode(WorkflowNodeBase):
|
||||
models.Index(fields=["identifier", "workflow_job"]),
|
||||
models.Index(fields=['identifier']),
|
||||
]
|
||||
ordering = ('pk',)
|
||||
|
||||
@property
|
||||
def event_processing_finished(self):
|
||||
@@ -345,11 +343,7 @@ class WorkflowJobNode(WorkflowNodeBase):
|
||||
)
|
||||
data.update(accepted_fields) # missing fields are handled in the scheduler
|
||||
# build ancestor artifacts, save them to node model for later
|
||||
# initialize from pre-seeded ancestor_artifacts (set on root nodes of
|
||||
# child workflows via seed_root_ancestor_artifacts to carry artifacts
|
||||
# from the parent workflow); exclude job_slice which is internal
|
||||
# metadata handled separately below
|
||||
aa_dict = {k: v for k, v in self.ancestor_artifacts.items() if k != 'job_slice'} if self.ancestor_artifacts else {}
|
||||
aa_dict = {}
|
||||
is_root_node = True
|
||||
for parent_node in self.get_parent_nodes():
|
||||
is_root_node = False
|
||||
@@ -370,13 +364,11 @@ class WorkflowJobNode(WorkflowNodeBase):
|
||||
data['survey_passwords'] = password_dict
|
||||
# process extra_vars
|
||||
extra_vars = data.get('extra_vars', {})
|
||||
if ujt_obj and isinstance(ujt_obj, JobTemplate):
|
||||
if ujt_obj and isinstance(ujt_obj, (JobTemplate, WorkflowJobTemplate)):
|
||||
if aa_dict:
|
||||
functional_aa_dict = copy(aa_dict)
|
||||
functional_aa_dict.pop('_ansible_no_log', None)
|
||||
extra_vars.update(functional_aa_dict)
|
||||
elif ujt_obj and isinstance(ujt_obj, WorkflowJobTemplate):
|
||||
pass # artifacts are applied via seed_root_ancestor_artifacts in the task manager
|
||||
|
||||
# Workflow Job extra_vars higher precedence than ancestor artifacts
|
||||
extra_vars.update(wj_special_vars)
|
||||
@@ -740,18 +732,6 @@ class WorkflowJob(UnifiedJob, WorkflowJobOptions, SurveyJobMixin, JobNotificatio
|
||||
wj = wj.get_workflow_job()
|
||||
return ancestors
|
||||
|
||||
def seed_root_ancestor_artifacts(self, artifacts):
|
||||
"""Apply parent workflow artifacts to root nodes so they propagate
|
||||
through the normal ancestor_artifacts channel instead of being
|
||||
baked into this workflow's extra_vars."""
|
||||
self.workflow_job_nodes.exclude(
|
||||
workflowjobnodes_success__isnull=False,
|
||||
).exclude(
|
||||
workflowjobnodes_failure__isnull=False,
|
||||
).exclude(
|
||||
workflowjobnodes_always__isnull=False,
|
||||
).update(ancestor_artifacts=artifacts)
|
||||
|
||||
def get_effective_artifacts(self, **kwargs):
|
||||
"""
|
||||
For downstream jobs of a workflow nested inside of a workflow,
|
||||
@@ -936,17 +916,6 @@ class WorkflowApproval(UnifiedJob, JobNotificationMixin):
|
||||
ScheduleWorkflowManager().schedule()
|
||||
return reverse('api:workflow_approval_deny', kwargs={'pk': self.pk}, request=request)
|
||||
|
||||
def cancel(self, job_explanation=None, is_chain=False):
|
||||
# WorkflowApprovals have no dispatcher process (they wait for human
|
||||
# input) and are excluded from TaskManager processing, so the base
|
||||
# cancel() would only set cancel_flag without ever transitioning the
|
||||
# status. We call super() for the flag, then transition directly.
|
||||
has_already_canceled = bool(self.status == 'canceled')
|
||||
super().cancel(job_explanation=job_explanation, is_chain=is_chain)
|
||||
if self.status != 'canceled' and not has_already_canceled:
|
||||
self.status = 'canceled'
|
||||
self.save(update_fields=['status'])
|
||||
|
||||
def signal_start(self, **kwargs):
|
||||
can_start = super(WorkflowApproval, self).signal_start(**kwargs)
|
||||
self.started = self.created
|
||||
|
||||
@@ -19,8 +19,13 @@ class ActivityStreamRegistrar(object):
|
||||
pre_delete.connect(activity_stream_delete, sender=model, dispatch_uid=str(self.__class__) + str(model) + "_delete")
|
||||
|
||||
for m2mfield in model._meta.many_to_many:
|
||||
m2m_attr = getattr(model, m2mfield.name)
|
||||
m2m_changed.connect(activity_stream_associate, sender=m2m_attr.through, dispatch_uid=str(self.__class__) + str(m2m_attr.through) + "_associate")
|
||||
try:
|
||||
m2m_attr = getattr(model, m2mfield.name)
|
||||
m2m_changed.connect(
|
||||
activity_stream_associate, sender=m2m_attr.through, dispatch_uid=str(self.__class__) + str(m2m_attr.through) + "_associate"
|
||||
)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def disconnect(self, model):
|
||||
if model in self.models:
|
||||
|
||||
@@ -48,6 +48,11 @@ class SimpleDAG(object):
|
||||
'''
|
||||
self.node_to_edges_by_label = dict()
|
||||
|
||||
def __contains__(self, obj):
|
||||
if self.node['node_object'] in self.node_obj_to_node_index:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __len__(self):
|
||||
return len(self.nodes)
|
||||
|
||||
|
||||
@@ -122,11 +122,8 @@ class WorkflowDAG(SimpleDAG):
|
||||
if not job:
|
||||
continue
|
||||
elif job.can_cancel:
|
||||
cancel_finished = False
|
||||
job.cancel()
|
||||
# If the job is not yet in a terminal state after .cancel(),
|
||||
# the TaskManager still needs to process it.
|
||||
if job.status not in ('successful', 'failed', 'canceled', 'error'):
|
||||
cancel_finished = False
|
||||
return cancel_finished
|
||||
|
||||
def is_workflow_done(self):
|
||||
|
||||
@@ -196,10 +196,6 @@ class WorkflowManager(TaskBase):
|
||||
workflow_job.start_args = '' # blank field to remove encrypted passwords
|
||||
workflow_job.save(update_fields=['status', 'start_args'])
|
||||
status_changed = True
|
||||
else:
|
||||
# Speed-up: schedule the task manager so it can process the
|
||||
# canceled pending jobs without waiting for the next cycle.
|
||||
ScheduleTaskManager().schedule()
|
||||
else:
|
||||
dnr_nodes = dag.mark_dnr_nodes()
|
||||
WorkflowJobNode.objects.bulk_update(dnr_nodes, ['do_not_run'])
|
||||
@@ -241,8 +237,6 @@ class WorkflowManager(TaskBase):
|
||||
job = spawn_node.unified_job_template.create_unified_job(**kv)
|
||||
spawn_node.job = job
|
||||
spawn_node.save()
|
||||
if spawn_node.ancestor_artifacts and isinstance(spawn_node.unified_job_template, WorkflowJobTemplate):
|
||||
job.seed_root_ancestor_artifacts(spawn_node.ancestor_artifacts)
|
||||
logger.debug('Spawned %s in %s for node %s', job.log_format, workflow_job.log_format, spawn_node.pk)
|
||||
can_start = True
|
||||
if isinstance(spawn_node.unified_job_template, WorkflowJobTemplate):
|
||||
@@ -449,29 +443,17 @@ class TaskManager(TaskBase):
|
||||
self.controlplane_ig = self.tm_models.instance_groups.controlplane_ig
|
||||
|
||||
def process_job_dep_failures(self, task):
|
||||
"""If job depends on a job that has failed or been canceled, mark as failed.
|
||||
|
||||
Returns True if a dep failure was found, False otherwise.
|
||||
"""
|
||||
"""If job depends on a job that has failed, mark as failed and handle misc stuff."""
|
||||
for dep in task.dependent_jobs.all():
|
||||
# if we detect a failed, error, or canceled dependency, go ahead and fail this task.
|
||||
if dep.status in ("error", "failed", "canceled"):
|
||||
# if we detect a failed or error dependency, go ahead and fail this task.
|
||||
if dep.status in ("error", "failed"):
|
||||
task.status = 'failed'
|
||||
if dep.status == 'canceled':
|
||||
logger.warning(f'Previous task canceled, failing task: {task.id} dep: {dep.id} task manager')
|
||||
task.job_explanation = 'Previous Task Canceled: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (
|
||||
get_type_for_model(type(dep)),
|
||||
dep.name,
|
||||
dep.id,
|
||||
)
|
||||
ScheduleWorkflowManager().schedule() # speedup for dependency chains in workflow, on workflow cancel
|
||||
else:
|
||||
logger.warning(f'Previous task failed, failing task: {task.id} dep: {dep.id} task manager')
|
||||
task.job_explanation = 'Previous Task Failed: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (
|
||||
get_type_for_model(type(dep)),
|
||||
dep.name,
|
||||
dep.id,
|
||||
)
|
||||
logger.warning(f'Previous task failed task: {task.id} dep: {dep.id} task manager')
|
||||
task.job_explanation = 'Previous Task Failed: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (
|
||||
get_type_for_model(type(dep)),
|
||||
dep.name,
|
||||
dep.id,
|
||||
)
|
||||
task.save(update_fields=['status', 'job_explanation'])
|
||||
task.websocket_emit_status('failed')
|
||||
self.pre_start_failed.append(task.id)
|
||||
@@ -563,17 +545,8 @@ class TaskManager(TaskBase):
|
||||
logger.warning("Task manager has reached time out while processing pending jobs, exiting loop early")
|
||||
break
|
||||
|
||||
if task.cancel_flag:
|
||||
logger.debug(f"Canceling pending task {task.log_format} because cancel_flag is set")
|
||||
task.status = 'canceled'
|
||||
task.job_explanation = gettext_noop("This job was canceled before it started.")
|
||||
task.save(update_fields=['status', 'job_explanation'])
|
||||
task.websocket_emit_status('canceled')
|
||||
self.pre_start_failed.append(task.id)
|
||||
ScheduleWorkflowManager().schedule()
|
||||
continue
|
||||
|
||||
if self.process_job_dep_failures(task):
|
||||
has_failed = self.process_job_dep_failures(task)
|
||||
if has_failed:
|
||||
continue
|
||||
|
||||
blocked_by = self.job_blocked_by(task)
|
||||
|
||||
@@ -25,8 +25,7 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
|
||||
log_data = log_data or {}
|
||||
log_data['inventory_id'] = inventory_id
|
||||
log_data['written_ct'] = 0
|
||||
# Dict mapping host name -> bool (True if a fact file was written)
|
||||
hosts_cached = {}
|
||||
hosts_cached = []
|
||||
|
||||
# Create the fact_cache directory inside artifacts_dir
|
||||
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
|
||||
@@ -38,14 +37,13 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
|
||||
last_write_time = None
|
||||
|
||||
for host in hosts:
|
||||
hosts_cached.append(host.name)
|
||||
if not host.ansible_facts_modified or (timeout and host.ansible_facts_modified < now() - datetime.timedelta(seconds=timeout)):
|
||||
hosts_cached[host.name] = False
|
||||
continue # facts are expired - do not write them
|
||||
|
||||
filepath = os.path.join(fact_cache_dir, host.name)
|
||||
if not os.path.realpath(filepath).startswith(fact_cache_dir):
|
||||
logger.error(f'facts for host {smart_str(host.name)} could not be cached')
|
||||
hosts_cached[host.name] = False
|
||||
continue
|
||||
|
||||
try:
|
||||
@@ -53,18 +51,9 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
|
||||
os.chmod(f.name, 0o600)
|
||||
json.dump(host.ansible_facts, f)
|
||||
log_data['written_ct'] += 1
|
||||
# Backdate the file by 2 seconds so finish_fact_cache can reliably
|
||||
# distinguish these reference files from files updated by ansible.
|
||||
# This guarantees fact file mtime < summary file mtime even with
|
||||
# zipfile's 2-second timestamp rounding during artifact transfer.
|
||||
mtime = os.path.getmtime(filepath)
|
||||
backdated = mtime - 2
|
||||
os.utime(filepath, (backdated, backdated))
|
||||
last_write_time = backdated
|
||||
hosts_cached[host.name] = True
|
||||
last_write_time = os.path.getmtime(filepath)
|
||||
except IOError:
|
||||
logger.error(f'facts for host {smart_str(host.name)} could not be cached')
|
||||
hosts_cached[host.name] = False
|
||||
continue
|
||||
|
||||
# Write summary file directly to the artifacts_dir
|
||||
@@ -73,6 +62,7 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
|
||||
summary_data = {
|
||||
'last_write_time': last_write_time,
|
||||
'hosts_cached': hosts_cached,
|
||||
'written_ct': log_data['written_ct'],
|
||||
}
|
||||
with open(summary_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(summary_data, f, indent=2)
|
||||
@@ -84,7 +74,7 @@ def start_fact_cache(hosts, artifacts_dir, timeout=None, inventory_id=None, log_
|
||||
msg='Inventory {inventory_id} host facts: updated {updated_ct}, cleared {cleared_ct}, unchanged {unmodified_ct}, took {delta:.3f} s',
|
||||
add_log_data=True,
|
||||
)
|
||||
def finish_fact_cache(host_qs, artifacts_dir, job_id=None, inventory_id=None, job_created=None, log_data=None):
|
||||
def finish_fact_cache(artifacts_dir, job_id=None, inventory_id=None, log_data=None):
|
||||
log_data = log_data or {}
|
||||
log_data['inventory_id'] = inventory_id
|
||||
log_data['updated_ct'] = 0
|
||||
@@ -104,9 +94,8 @@ def finish_fact_cache(host_qs, artifacts_dir, job_id=None, inventory_id=None, jo
|
||||
logger.error(f'Error reading summary file at {summary_path}: {e}')
|
||||
return
|
||||
|
||||
hosts_cached_map = summary.get('hosts_cached', {})
|
||||
host_names = list(hosts_cached_map.keys())
|
||||
hosts_cached = host_qs.filter(name__in=host_names).order_by('id').iterator()
|
||||
host_names = summary.get('hosts_cached', [])
|
||||
hosts_cached = Host.objects.filter(name__in=host_names).order_by('id').iterator()
|
||||
# Path where individual fact files were written
|
||||
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
|
||||
hosts_to_update = []
|
||||
@@ -147,35 +136,16 @@ def finish_fact_cache(host_qs, artifacts_dir, job_id=None, inventory_id=None, jo
|
||||
else:
|
||||
log_data['unmodified_ct'] += 1
|
||||
else:
|
||||
# File is missing. Only interpret this as "ansible cleared facts" if
|
||||
# start_fact_cache actually wrote a file for this host (i.e. the host
|
||||
# had valid, non-expired facts before the job ran). If no file was
|
||||
# ever written, the missing file is expected and not a clear signal.
|
||||
if not hosts_cached_map.get(host.name):
|
||||
log_data['unmodified_ct'] += 1
|
||||
continue
|
||||
|
||||
# if the file goes missing, ansible removed it (likely via clear_facts)
|
||||
# if the file goes missing, but the host has not started facts, then we should not clear the facts
|
||||
if job_created and host.ansible_facts_modified and host.ansible_facts_modified > job_created:
|
||||
logger.warning(
|
||||
f'Skipping fact clear for host {smart_str(host.name)} in job {job_id} '
|
||||
f'inventory {inventory_id}: host ansible_facts_modified '
|
||||
f'({host.ansible_facts_modified.isoformat()}) is after this job\'s '
|
||||
f'created time ({job_created.isoformat()}). '
|
||||
f'A concurrent job likely updated this host\'s facts while this job was running.'
|
||||
)
|
||||
log_data['unmodified_ct'] += 1
|
||||
else:
|
||||
host.ansible_facts = {}
|
||||
host.ansible_facts_modified = now()
|
||||
hosts_to_update.append(host)
|
||||
logger.info(f'Facts cleared for inventory {smart_str(host.inventory.name)} host {smart_str(host.name)}')
|
||||
log_data['cleared_ct'] += 1
|
||||
host.ansible_facts = {}
|
||||
host.ansible_facts_modified = now()
|
||||
hosts_to_update.append(host)
|
||||
logger.info(f'Facts cleared for inventory {smart_str(host.inventory.name)} host {smart_str(host.name)}')
|
||||
log_data['cleared_ct'] += 1
|
||||
|
||||
if len(hosts_to_update) >= 100:
|
||||
bulk_update_sorted_by_id(Host, hosts_to_update, fields=['ansible_facts', 'ansible_facts_modified'])
|
||||
hosts_to_update = []
|
||||
|
||||
bulk_update_sorted_by_id(Host, hosts_to_update, fields=['ansible_facts', 'ansible_facts_modified'])
|
||||
logger.debug(f'Updated {log_data["updated_ct"]} host facts for inventory {inventory_id} in job {job_id}')
|
||||
|
||||
@@ -17,6 +17,7 @@ import urllib.parse as urlparse
|
||||
|
||||
# Django
|
||||
from django.conf import settings
|
||||
from django.db import transaction
|
||||
|
||||
# Shared code for the AWX platform
|
||||
from awx_plugins.interfaces._temporary_private_container_api import CONTAINER_ROOT, get_incontainer_path
|
||||
@@ -94,7 +95,6 @@ from flags.state import flag_enabled
|
||||
|
||||
# Workload Identity
|
||||
from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope
|
||||
from awx.main.utils.workload_identity import retrieve_workload_identity_jwt_with_claims
|
||||
|
||||
logger = logging.getLogger('awx.main.tasks.jobs')
|
||||
|
||||
@@ -104,6 +104,11 @@ def populate_claims_for_workload(unified_job) -> dict:
|
||||
Extract JWT claims from a Controller workload for the aap_controller_automation_job scope.
|
||||
"""
|
||||
|
||||
# Related objects in the UnifiedJob model, applies to all job types
|
||||
organization = getattr_dne(unified_job, 'organization')
|
||||
ujt = getattr_dne(unified_job, 'unified_job_template')
|
||||
instance_group = getattr_dne(unified_job, 'instance_group')
|
||||
|
||||
claims = {
|
||||
AutomationControllerJobScope.CLAIM_JOB_ID: unified_job.id,
|
||||
AutomationControllerJobScope.CLAIM_JOB_NAME: unified_job.name,
|
||||
@@ -158,24 +163,6 @@ def populate_claims_for_workload(unified_job) -> dict:
|
||||
return claims
|
||||
|
||||
|
||||
def retrieve_workload_identity_jwt(
|
||||
unified_job: UnifiedJob,
|
||||
audience: str,
|
||||
scope: str,
|
||||
workload_ttl_seconds: int | None = None,
|
||||
) -> str:
|
||||
"""Retrieve JWT token from workload claims.
|
||||
Raises:
|
||||
RuntimeError: if the workload identity client is not configured.
|
||||
"""
|
||||
return retrieve_workload_identity_jwt_with_claims(
|
||||
populate_claims_for_workload(unified_job),
|
||||
audience,
|
||||
scope,
|
||||
workload_ttl_seconds,
|
||||
)
|
||||
|
||||
|
||||
def with_path_cleanup(f):
|
||||
@functools.wraps(f)
|
||||
def _wrapped(self, *args, **kwargs):
|
||||
@@ -202,7 +189,6 @@ def dispatch_waiting_jobs(binder):
|
||||
if not kwargs:
|
||||
kwargs = {}
|
||||
binder.control('run', data={'task': serialize_task(uj._get_task_class()), 'args': [uj.id], 'kwargs': kwargs, 'uuid': uj.celery_task_id})
|
||||
UnifiedJob.objects.filter(pk=uj.pk, status='waiting').update(status='running', start_args='')
|
||||
|
||||
|
||||
class BaseTask(object):
|
||||
@@ -217,63 +203,6 @@ class BaseTask(object):
|
||||
self.update_attempts = int(getattr(settings, 'DISPATCHER_DB_DOWNTOWN_TOLLERANCE', settings.DISPATCHER_DB_DOWNTIME_TOLERANCE) / 5)
|
||||
self.runner_callback = self.callback_class(model=self.model)
|
||||
|
||||
@functools.cached_property
|
||||
def _credentials(self):
|
||||
"""
|
||||
Credentials for the task execution.
|
||||
Fetches credentials once using build_credentials_list() and stores
|
||||
them for the duration of the task to avoid redundant database queries.
|
||||
"""
|
||||
credentials_list = self.build_credentials_list(self.instance)
|
||||
# Convert to list to prevent re-evaluation of QuerySet
|
||||
return list(credentials_list)
|
||||
|
||||
def populate_workload_identity_tokens(self, additional_credentials=None):
|
||||
"""
|
||||
Populate credentials with workload identity tokens.
|
||||
|
||||
Sets the context on Credential objects that have input sources
|
||||
using compatible external credential types.
|
||||
"""
|
||||
credentials = list(self._credentials)
|
||||
if additional_credentials:
|
||||
credentials.extend(additional_credentials)
|
||||
credential_input_sources = (
|
||||
(credential.context, src)
|
||||
for credential in credentials
|
||||
for src in credential.input_sources.all()
|
||||
if any(
|
||||
field.get('id') == 'workload_identity_token' and field.get('internal')
|
||||
for field in src.source_credential.credential_type.inputs.get('fields', [])
|
||||
)
|
||||
)
|
||||
for credential_ctx, input_src in credential_input_sources:
|
||||
if flag_enabled("FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED"):
|
||||
effective_timeout = self.get_instance_timeout(self.instance)
|
||||
workload_ttl = effective_timeout if effective_timeout else None
|
||||
try:
|
||||
jwt = retrieve_workload_identity_jwt(
|
||||
self.instance,
|
||||
audience=input_src.source_credential.get_input('jwt_aud'),
|
||||
scope=AutomationControllerJobScope.name,
|
||||
workload_ttl_seconds=workload_ttl,
|
||||
)
|
||||
# Store token keyed by input source PK, since a credential can have
|
||||
# multiple input sources (one per field), each potentially with a different audience
|
||||
credential_ctx[input_src.pk] = {"workload_identity_token": jwt}
|
||||
except Exception as e:
|
||||
self.instance.job_explanation = (
|
||||
f'Could not generate workload identity token for credential {input_src.source_credential.name} used in this job. Error:\n{e}'
|
||||
)
|
||||
self.instance.status = 'error'
|
||||
self.instance.save()
|
||||
else:
|
||||
self.instance.job_explanation = (
|
||||
f'Flag FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED is not enabled, required for credential {input_src.source_credential.name} used in this job.'
|
||||
)
|
||||
self.instance.status = 'error'
|
||||
self.instance.save()
|
||||
|
||||
def update_model(self, pk, _attempt=0, **updates):
|
||||
return update_model(self.model, pk, _attempt=0, _max_attempts=self.update_attempts, **updates)
|
||||
|
||||
@@ -425,19 +354,6 @@ class BaseTask(object):
|
||||
private_data_files['credentials'][credential] = self.write_private_data_file(private_data_dir, None, data, sub_dir='env')
|
||||
for credential, data in private_data.get('certificates', {}).items():
|
||||
self.write_private_data_file(private_data_dir, 'ssh_key_data-cert.pub', data, sub_dir=os.path.join('artifacts', str(self.instance.id)))
|
||||
|
||||
# Copy vendor collections to private_data_dir for indirect node counting
|
||||
# This makes external query files available to the callback plugin in EEs
|
||||
if flag_enabled("FEATURE_INDIRECT_NODE_COUNTING_ENABLED"):
|
||||
vendor_src = '/var/lib/awx/vendor_collections'
|
||||
vendor_dest = os.path.join(private_data_dir, 'vendor_collections')
|
||||
if os.path.exists(vendor_src):
|
||||
try:
|
||||
shutil.copytree(vendor_src, vendor_dest)
|
||||
logger.debug(f"Copied vendor collections from {vendor_src} to {vendor_dest}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to copy vendor collections: {e}")
|
||||
|
||||
return private_data_files, ssh_key_data
|
||||
|
||||
def build_passwords(self, instance, runtime_passwords):
|
||||
@@ -511,7 +427,6 @@ class BaseTask(object):
|
||||
return []
|
||||
|
||||
def get_instance_timeout(self, instance):
|
||||
"""Return the effective job timeout in seconds."""
|
||||
global_timeout_setting_name = instance._global_timeout_setting()
|
||||
if global_timeout_setting_name:
|
||||
global_timeout = getattr(settings, global_timeout_setting_name, 0)
|
||||
@@ -620,32 +535,48 @@ class BaseTask(object):
|
||||
def should_use_fact_cache(self):
|
||||
return False
|
||||
|
||||
def transition_status(self, pk: int) -> bool:
|
||||
"""Atomically transition status to running, if False returned, another process got it"""
|
||||
with transaction.atomic():
|
||||
# Explanation of parts for the fetch:
|
||||
# .values - avoid loading a full object, this is known to lead to deadlocks due to signals
|
||||
# the signals load other related rows which another process may be locking, and happens in practice
|
||||
# of=('self',) - keeps FK tables out of the lock list, another way deadlocks can happen
|
||||
# .get - just load the single job
|
||||
instance_data = UnifiedJob.objects.select_for_update(of=('self',)).values('status', 'cancel_flag').get(pk=pk)
|
||||
|
||||
# If status is not waiting (obtained under lock) then this process does not have clearence to run
|
||||
if instance_data['status'] == 'waiting':
|
||||
if instance_data['cancel_flag']:
|
||||
updated_status = 'canceled'
|
||||
else:
|
||||
updated_status = 'running'
|
||||
# Explanation of the update:
|
||||
# .filter - again, do not load the full object
|
||||
# .update - a bulk update on just that one row, avoid loading unintended data
|
||||
UnifiedJob.objects.filter(pk=pk).update(status=updated_status, start_args='')
|
||||
elif instance_data['status'] == 'running':
|
||||
logger.info(f'Job {pk} is being ran by another process, exiting')
|
||||
return False
|
||||
return True
|
||||
|
||||
@with_path_cleanup
|
||||
@with_signal_handling
|
||||
def run(self, pk, **kwargs):
|
||||
"""
|
||||
Run the job/task and capture its output.
|
||||
"""
|
||||
|
||||
if not self.instance: # Used to skip fetch for local runs
|
||||
# Load the instance
|
||||
self.instance = self.update_model(pk)
|
||||
|
||||
# status should be "running" from dispatch_waiting_jobs,
|
||||
# but may still be "waiting" if the worker picked this up before the status update landed.
|
||||
if self.instance.status == 'waiting':
|
||||
UnifiedJob.objects.filter(pk=pk).update(status="running", start_args='')
|
||||
self.instance.refresh_from_db()
|
||||
if not self.transition_status(pk):
|
||||
logger.info(f'Job {pk} is being ran by another process, exiting')
|
||||
return
|
||||
|
||||
# Load the instance
|
||||
self.instance = self.update_model(pk)
|
||||
if self.instance.status != 'running':
|
||||
logger.error(f'Not starting {self.instance.status} task pk={pk} because its status "{self.instance.status}" is not expected')
|
||||
return
|
||||
|
||||
if self.instance.cancel_flag:
|
||||
self.instance = self.update_model(pk, status='canceled')
|
||||
self.instance.websocket_emit_status('canceled')
|
||||
return
|
||||
|
||||
self.instance.websocket_emit_status("running")
|
||||
status, rc = 'error', None
|
||||
self.runner_callback.event_ct = 0
|
||||
@@ -684,12 +615,6 @@ class BaseTask(object):
|
||||
if not os.path.exists(settings.AWX_ISOLATION_BASE_PATH):
|
||||
raise RuntimeError('AWX_ISOLATION_BASE_PATH=%s does not exist' % settings.AWX_ISOLATION_BASE_PATH)
|
||||
|
||||
if flag_enabled("FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED"):
|
||||
logger.info(f'Generating workload identity tokens for {self.instance.log_format}')
|
||||
self.populate_workload_identity_tokens()
|
||||
if self.instance.status == 'error':
|
||||
raise RuntimeError('not starting %s task' % self.instance.status)
|
||||
|
||||
# May have to serialize the value
|
||||
private_data_files, ssh_key_data = self.build_private_data_files(self.instance, private_data_dir)
|
||||
passwords = self.build_passwords(self.instance, kwargs)
|
||||
@@ -707,7 +632,7 @@ class BaseTask(object):
|
||||
|
||||
self.runner_callback.job_created = str(self.instance.created)
|
||||
|
||||
credentials = self._credentials
|
||||
credentials = self.build_credentials_list(self.instance)
|
||||
|
||||
container_root = None
|
||||
if settings.IS_K8S and isinstance(self.instance, ProjectUpdate):
|
||||
@@ -1002,29 +927,6 @@ class RunJob(SourceControlMixin, BaseTask):
|
||||
model = Job
|
||||
event_model = JobEvent
|
||||
|
||||
def _extract_credentials_of_kind(self, kind: str):
|
||||
return (cred for cred in self._credentials if cred.credential_type.kind == kind)
|
||||
|
||||
@property
|
||||
def _machine_credential(self) -> object:
|
||||
"""Get machine credential."""
|
||||
return next(self._extract_credentials_of_kind('ssh'), None)
|
||||
|
||||
@property
|
||||
def _vault_credentials(self) -> list[object]:
|
||||
"""Get vault credentials."""
|
||||
return list(self._extract_credentials_of_kind('vault'))
|
||||
|
||||
@property
|
||||
def _network_credentials(self) -> list[object]:
|
||||
"""Get network credentials."""
|
||||
return list(self._extract_credentials_of_kind('net'))
|
||||
|
||||
@property
|
||||
def _cloud_credentials(self) -> list[object]:
|
||||
"""Get cloud credentials."""
|
||||
return list(self._extract_credentials_of_kind('cloud'))
|
||||
|
||||
def build_private_data(self, job, private_data_dir):
|
||||
"""
|
||||
Returns a dict of the form
|
||||
@@ -1042,7 +944,7 @@ class RunJob(SourceControlMixin, BaseTask):
|
||||
}
|
||||
"""
|
||||
private_data = {'credentials': {}}
|
||||
for credential in self._credentials:
|
||||
for credential in job.credentials.prefetch_related('input_sources__source_credential').all():
|
||||
# If we were sent SSH credentials, decrypt them and send them
|
||||
# back (they will be written to a temporary file).
|
||||
if credential.has_input('ssh_key_data'):
|
||||
@@ -1058,14 +960,14 @@ class RunJob(SourceControlMixin, BaseTask):
|
||||
and ansible-vault.
|
||||
"""
|
||||
passwords = super(RunJob, self).build_passwords(job, runtime_passwords)
|
||||
cred = self._machine_credential
|
||||
cred = job.machine_credential
|
||||
if cred:
|
||||
for field in ('ssh_key_unlock', 'ssh_password', 'become_password', 'vault_password'):
|
||||
value = runtime_passwords.get(field, cred.get_input('password' if field == 'ssh_password' else field, default=''))
|
||||
if value not in ('', 'ASK'):
|
||||
passwords[field] = value
|
||||
|
||||
for cred in self._vault_credentials:
|
||||
for cred in job.vault_credentials:
|
||||
field = 'vault_password'
|
||||
vault_id = cred.get_input('vault_id', default=None)
|
||||
if vault_id:
|
||||
@@ -1081,7 +983,7 @@ class RunJob(SourceControlMixin, BaseTask):
|
||||
key unlock over network key unlock.
|
||||
'''
|
||||
if 'ssh_key_unlock' not in passwords:
|
||||
for cred in self._network_credentials:
|
||||
for cred in job.network_credentials:
|
||||
if cred.inputs.get('ssh_key_unlock'):
|
||||
passwords['ssh_key_unlock'] = runtime_passwords.get('ssh_key_unlock', cred.get_input('ssh_key_unlock', default=''))
|
||||
break
|
||||
@@ -1116,11 +1018,11 @@ class RunJob(SourceControlMixin, BaseTask):
|
||||
|
||||
# Set environment variables for cloud credentials.
|
||||
cred_files = private_data_files.get('credentials', {})
|
||||
for cloud_cred in self._cloud_credentials:
|
||||
for cloud_cred in job.cloud_credentials:
|
||||
if cloud_cred and cloud_cred.credential_type.namespace == 'openstack' and cred_files.get(cloud_cred, ''):
|
||||
env['OS_CLIENT_CONFIG_FILE'] = get_incontainer_path(cred_files.get(cloud_cred, ''), private_data_dir)
|
||||
|
||||
for network_cred in self._network_credentials:
|
||||
for network_cred in job.network_credentials:
|
||||
env['ANSIBLE_NET_USERNAME'] = network_cred.get_input('username', default='')
|
||||
env['ANSIBLE_NET_PASSWORD'] = network_cred.get_input('password', default='')
|
||||
|
||||
@@ -1163,11 +1065,6 @@ class RunJob(SourceControlMixin, BaseTask):
|
||||
if 'callbacks_enabled' in config_values:
|
||||
env['ANSIBLE_CALLBACKS_ENABLED'] += ':' + config_values['callbacks_enabled']
|
||||
|
||||
# Add vendor collections path for external query file discovery
|
||||
vendor_collections_path = os.path.join(CONTAINER_ROOT, 'vendor_collections')
|
||||
env['ANSIBLE_COLLECTIONS_PATH'] = f"{vendor_collections_path}:{env['ANSIBLE_COLLECTIONS_PATH']}"
|
||||
logger.debug(f"ANSIBLE_COLLECTIONS_PATH updated for vendor collections: {env['ANSIBLE_COLLECTIONS_PATH']}")
|
||||
|
||||
return env
|
||||
|
||||
def build_args(self, job, private_data_dir, passwords):
|
||||
@@ -1175,7 +1072,7 @@ class RunJob(SourceControlMixin, BaseTask):
|
||||
Build command line argument list for running ansible-playbook,
|
||||
optionally using ssh-agent for public/private key authentication.
|
||||
"""
|
||||
creds = self._machine_credential
|
||||
creds = job.machine_credential
|
||||
|
||||
ssh_username, become_username, become_method = '', '', ''
|
||||
if creds:
|
||||
@@ -1327,16 +1224,10 @@ class RunJob(SourceControlMixin, BaseTask):
|
||||
return
|
||||
if self.should_use_fact_cache() and self.runner_callback.artifacts_processed:
|
||||
job.log_lifecycle("finish_job_fact_cache")
|
||||
if job.inventory.kind == 'constructed':
|
||||
hosts_qs = job.get_source_hosts_for_constructed_inventory()
|
||||
else:
|
||||
hosts_qs = job.inventory.hosts
|
||||
finish_fact_cache(
|
||||
hosts_qs,
|
||||
artifacts_dir=os.path.join(private_data_dir, 'artifacts', str(job.id)),
|
||||
job_id=job.id,
|
||||
inventory_id=job.inventory_id,
|
||||
job_created=job.created,
|
||||
)
|
||||
|
||||
def final_run_hook(self, job, status, private_data_dir):
|
||||
@@ -1683,7 +1574,7 @@ class RunProjectUpdate(BaseTask):
|
||||
return params
|
||||
|
||||
def build_credentials_list(self, project_update):
|
||||
if project_update.credential:
|
||||
if project_update.scm_type == 'insights' and project_update.credential:
|
||||
return [project_update.credential]
|
||||
return []
|
||||
|
||||
@@ -1866,24 +1757,6 @@ class RunInventoryUpdate(SourceControlMixin, BaseTask):
|
||||
# All credentials not used by inventory source injector
|
||||
return inventory_update.get_extra_credentials()
|
||||
|
||||
def populate_workload_identity_tokens(self, additional_credentials=None):
|
||||
"""Also generate OIDC tokens for the cloud credential.
|
||||
|
||||
The cloud credential is not in _credentials (it is handled by the
|
||||
inventory source injector), but it may still need a workload identity
|
||||
token generated for it.
|
||||
"""
|
||||
cloud_cred = self.instance.get_cloud_credential()
|
||||
creds = list(additional_credentials or [])
|
||||
if cloud_cred:
|
||||
creds.append(cloud_cred)
|
||||
super().populate_workload_identity_tokens(additional_credentials=creds or None)
|
||||
# Override get_cloud_credential on this instance so the injector
|
||||
# uses the credential with OIDC context instead of doing a fresh
|
||||
# DB fetch that would lose it.
|
||||
if cloud_cred and cloud_cred.context:
|
||||
self.instance.get_cloud_credential = lambda: cloud_cred
|
||||
|
||||
def build_project_dir(self, inventory_update, private_data_dir):
|
||||
source_project = None
|
||||
if inventory_update.inventory_source:
|
||||
|
||||
@@ -393,9 +393,9 @@ def evaluate_policy(instance):
|
||||
raise PolicyEvaluationError(_('Following certificate settings are missing for OPA_AUTH_TYPE=Certificate: {}').format(cert_settings_missing))
|
||||
|
||||
query_paths = [
|
||||
('Organization', instance.organization.opa_query_path if instance.organization else None),
|
||||
('Inventory', instance.inventory.opa_query_path if instance.inventory else None),
|
||||
('Job template', instance.job_template.opa_query_path if instance.job_template else None),
|
||||
('Organization', instance.organization.opa_query_path),
|
||||
('Inventory', instance.inventory.opa_query_path),
|
||||
('Job template', instance.job_template.opa_query_path),
|
||||
]
|
||||
violations = dict()
|
||||
errors = dict()
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
---
|
||||
authors:
|
||||
- AWX Project Contributors <awx-project@googlegroups.com>
|
||||
dependencies: {}
|
||||
description: External query testing collection. No embedded query file. Not for use in production.
|
||||
documentation: https://github.com/ansible/awx
|
||||
homepage: https://github.com/ansible/awx
|
||||
issues: https://github.com/ansible/awx
|
||||
license:
|
||||
- GPL-3.0-or-later
|
||||
name: external
|
||||
namespace: demo
|
||||
readme: README.md
|
||||
repository: https://github.com/ansible/awx
|
||||
tags:
|
||||
- demo
|
||||
- testing
|
||||
- external_query
|
||||
version: 1.0.0
|
||||
@@ -1,78 +0,0 @@
|
||||
#!/usr/bin/python
|
||||
|
||||
# Same licensing as AWX
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
__metaclass__ = type
|
||||
|
||||
DOCUMENTATION = r'''
|
||||
---
|
||||
module: example
|
||||
|
||||
short_description: Module for specific live tests
|
||||
|
||||
version_added: "2.0.0"
|
||||
|
||||
description: This module is part of a test collection in local source. Used for external query testing.
|
||||
|
||||
options:
|
||||
host_name:
|
||||
description: Name to return as the host name.
|
||||
required: false
|
||||
type: str
|
||||
|
||||
author:
|
||||
- AWX Live Tests
|
||||
'''
|
||||
|
||||
EXAMPLES = r'''
|
||||
- name: Test with defaults
|
||||
demo.external.example:
|
||||
|
||||
- name: Test with custom host name
|
||||
demo.external.example:
|
||||
host_name: foo_host
|
||||
'''
|
||||
|
||||
RETURN = r'''
|
||||
direct_host_name:
|
||||
description: The name of the host, this will be collected with the feature.
|
||||
type: str
|
||||
returned: always
|
||||
sample: 'foo_host'
|
||||
'''
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
|
||||
def run_module():
|
||||
module_args = dict(
|
||||
host_name=dict(type='str', required=False, default='foo_host_default'),
|
||||
)
|
||||
|
||||
result = dict(
|
||||
changed=False,
|
||||
other_data='sample_string',
|
||||
)
|
||||
|
||||
module = AnsibleModule(argument_spec=module_args, supports_check_mode=True)
|
||||
|
||||
if module.check_mode:
|
||||
module.exit_json(**result)
|
||||
|
||||
result['direct_host_name'] = module.params['host_name']
|
||||
result['nested_host_name'] = {'host_name': module.params['host_name']}
|
||||
result['name'] = 'vm-foo'
|
||||
|
||||
# non-cononical facts
|
||||
result['device_type'] = 'Fake Host'
|
||||
|
||||
module.exit_json(**result)
|
||||
|
||||
|
||||
def main():
|
||||
run_module()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,19 +0,0 @@
|
||||
---
|
||||
authors:
|
||||
- AWX Project Contributors <awx-project@googlegroups.com>
|
||||
dependencies: {}
|
||||
description: External query testing collection v1.5.0. No embedded query file. Not for use in production.
|
||||
documentation: https://github.com/ansible/awx
|
||||
homepage: https://github.com/ansible/awx
|
||||
issues: https://github.com/ansible/awx
|
||||
license:
|
||||
- GPL-3.0-or-later
|
||||
name: external
|
||||
namespace: demo
|
||||
readme: README.md
|
||||
repository: https://github.com/ansible/awx
|
||||
tags:
|
||||
- demo
|
||||
- testing
|
||||
- external_query
|
||||
version: 1.5.0
|
||||
@@ -1,78 +0,0 @@
|
||||
#!/usr/bin/python
|
||||
|
||||
# Same licensing as AWX
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
__metaclass__ = type
|
||||
|
||||
DOCUMENTATION = r'''
|
||||
---
|
||||
module: example
|
||||
|
||||
short_description: Module for specific live tests
|
||||
|
||||
version_added: "2.0.0"
|
||||
|
||||
description: This module is part of a test collection in local source. Used for external query testing.
|
||||
|
||||
options:
|
||||
host_name:
|
||||
description: Name to return as the host name.
|
||||
required: false
|
||||
type: str
|
||||
|
||||
author:
|
||||
- AWX Live Tests
|
||||
'''
|
||||
|
||||
EXAMPLES = r'''
|
||||
- name: Test with defaults
|
||||
demo.external.example:
|
||||
|
||||
- name: Test with custom host name
|
||||
demo.external.example:
|
||||
host_name: foo_host
|
||||
'''
|
||||
|
||||
RETURN = r'''
|
||||
direct_host_name:
|
||||
description: The name of the host, this will be collected with the feature.
|
||||
type: str
|
||||
returned: always
|
||||
sample: 'foo_host'
|
||||
'''
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
|
||||
def run_module():
|
||||
module_args = dict(
|
||||
host_name=dict(type='str', required=False, default='foo_host_default'),
|
||||
)
|
||||
|
||||
result = dict(
|
||||
changed=False,
|
||||
other_data='sample_string',
|
||||
)
|
||||
|
||||
module = AnsibleModule(argument_spec=module_args, supports_check_mode=True)
|
||||
|
||||
if module.check_mode:
|
||||
module.exit_json(**result)
|
||||
|
||||
result['direct_host_name'] = module.params['host_name']
|
||||
result['nested_host_name'] = {'host_name': module.params['host_name']}
|
||||
result['name'] = 'vm-foo'
|
||||
|
||||
# non-cononical facts
|
||||
result['device_type'] = 'Fake Host'
|
||||
|
||||
module.exit_json(**result)
|
||||
|
||||
|
||||
def main():
|
||||
run_module()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,19 +0,0 @@
|
||||
---
|
||||
authors:
|
||||
- AWX Project Contributors <awx-project@googlegroups.com>
|
||||
dependencies: {}
|
||||
description: External query testing collection v3.0.0. No embedded query file. Not for use in production.
|
||||
documentation: https://github.com/ansible/awx
|
||||
homepage: https://github.com/ansible/awx
|
||||
issues: https://github.com/ansible/awx
|
||||
license:
|
||||
- GPL-3.0-or-later
|
||||
name: external
|
||||
namespace: demo
|
||||
readme: README.md
|
||||
repository: https://github.com/ansible/awx
|
||||
tags:
|
||||
- demo
|
||||
- testing
|
||||
- external_query
|
||||
version: 3.0.0
|
||||
@@ -1,78 +0,0 @@
|
||||
#!/usr/bin/python
|
||||
|
||||
# Same licensing as AWX
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
__metaclass__ = type
|
||||
|
||||
DOCUMENTATION = r'''
|
||||
---
|
||||
module: example
|
||||
|
||||
short_description: Module for specific live tests
|
||||
|
||||
version_added: "2.0.0"
|
||||
|
||||
description: This module is part of a test collection in local source. Used for external query testing.
|
||||
|
||||
options:
|
||||
host_name:
|
||||
description: Name to return as the host name.
|
||||
required: false
|
||||
type: str
|
||||
|
||||
author:
|
||||
- AWX Live Tests
|
||||
'''
|
||||
|
||||
EXAMPLES = r'''
|
||||
- name: Test with defaults
|
||||
demo.external.example:
|
||||
|
||||
- name: Test with custom host name
|
||||
demo.external.example:
|
||||
host_name: foo_host
|
||||
'''
|
||||
|
||||
RETURN = r'''
|
||||
direct_host_name:
|
||||
description: The name of the host, this will be collected with the feature.
|
||||
type: str
|
||||
returned: always
|
||||
sample: 'foo_host'
|
||||
'''
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
|
||||
def run_module():
|
||||
module_args = dict(
|
||||
host_name=dict(type='str', required=False, default='foo_host_default'),
|
||||
)
|
||||
|
||||
result = dict(
|
||||
changed=False,
|
||||
other_data='sample_string',
|
||||
)
|
||||
|
||||
module = AnsibleModule(argument_spec=module_args, supports_check_mode=True)
|
||||
|
||||
if module.check_mode:
|
||||
module.exit_json(**result)
|
||||
|
||||
result['direct_host_name'] = module.params['host_name']
|
||||
result['nested_host_name'] = {'host_name': module.params['host_name']}
|
||||
result['name'] = 'vm-foo'
|
||||
|
||||
# non-cononical facts
|
||||
result['device_type'] = 'Fake Host'
|
||||
|
||||
module.exit_json(**result)
|
||||
|
||||
|
||||
def main():
|
||||
run_module()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,11 +0,0 @@
|
||||
---
|
||||
- hosts: all
|
||||
gather_facts: false
|
||||
connection: local
|
||||
tasks:
|
||||
- name: Set artifacts via set_stats
|
||||
ansible.builtin.set_stats:
|
||||
data: "{{ stats_data }}"
|
||||
per_host: false
|
||||
aggregate: false
|
||||
when: stats_data is defined
|
||||
@@ -1,21 +0,0 @@
|
||||
---
|
||||
# Generated by Claude Opus 4.6 (claude-opus-4-6).
|
||||
|
||||
- hosts: all
|
||||
vars:
|
||||
extra_value: ""
|
||||
gather_facts: false
|
||||
connection: local
|
||||
tasks:
|
||||
- name: set a custom fact
|
||||
set_fact:
|
||||
foo: "bar{{ extra_value }}"
|
||||
bar:
|
||||
a:
|
||||
b:
|
||||
- "c"
|
||||
- "d"
|
||||
cacheable: true
|
||||
- name: sleep to create overlap window for concurrent job testing
|
||||
wait_for:
|
||||
timeout: 2
|
||||
@@ -1,5 +0,0 @@
|
||||
---
|
||||
collections:
|
||||
- name: 'file:///tmp/live_tests/host_query_external_v1_0_0'
|
||||
type: git
|
||||
version: devel
|
||||
@@ -1,8 +0,0 @@
|
||||
---
|
||||
- hosts: all
|
||||
gather_facts: false
|
||||
connection: local
|
||||
tasks:
|
||||
- demo.external.example:
|
||||
register: result
|
||||
- debug: var=result
|
||||
@@ -1,5 +0,0 @@
|
||||
---
|
||||
collections:
|
||||
- name: 'file:///tmp/live_tests/host_query_external_v1_5_0'
|
||||
type: git
|
||||
version: devel
|
||||
@@ -1,8 +0,0 @@
|
||||
---
|
||||
- hosts: all
|
||||
gather_facts: false
|
||||
connection: local
|
||||
tasks:
|
||||
- demo.external.example:
|
||||
register: result
|
||||
- debug: var=result
|
||||
@@ -1,5 +0,0 @@
|
||||
---
|
||||
collections:
|
||||
- name: 'file:///tmp/live_tests/host_query_external_v3_0_0'
|
||||
type: git
|
||||
version: devel
|
||||
@@ -1,8 +0,0 @@
|
||||
---
|
||||
- hosts: all
|
||||
gather_facts: false
|
||||
connection: local
|
||||
tasks:
|
||||
- demo.external.example:
|
||||
register: result
|
||||
- debug: var=result
|
||||
@@ -1,84 +0,0 @@
|
||||
import pytest
|
||||
from awx.api.versioning import reverse
|
||||
from rest_framework import status
|
||||
|
||||
from awx.main.models.jobs import JobTemplate
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestConfigEndpointFields:
|
||||
def test_base_fields_all_users(self, get, rando):
|
||||
url = reverse('api:api_v2_config_view')
|
||||
response = get(url, rando, expect=200)
|
||||
|
||||
assert 'time_zone' in response.data
|
||||
assert 'license_info' in response.data
|
||||
assert 'version' in response.data
|
||||
assert 'eula' in response.data
|
||||
assert 'analytics_status' in response.data
|
||||
assert 'analytics_collectors' in response.data
|
||||
assert 'become_methods' in response.data
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"role_type",
|
||||
[
|
||||
"superuser",
|
||||
"system_auditor",
|
||||
"org_admin",
|
||||
"org_auditor",
|
||||
"org_project_admin",
|
||||
],
|
||||
)
|
||||
def test_privileged_users_conditional_fields(self, get, user, organization, admin, role_type):
|
||||
url = reverse('api:api_v2_config_view')
|
||||
|
||||
if role_type == "superuser":
|
||||
test_user = admin
|
||||
elif role_type == "system_auditor":
|
||||
test_user = user('system-auditor', is_superuser=False)
|
||||
test_user.is_system_auditor = True
|
||||
test_user.save()
|
||||
elif role_type == "org_admin":
|
||||
test_user = user('org-admin', is_superuser=False)
|
||||
organization.admin_role.members.add(test_user)
|
||||
elif role_type == "org_auditor":
|
||||
test_user = user('org-auditor', is_superuser=False)
|
||||
organization.auditor_role.members.add(test_user)
|
||||
elif role_type == "org_project_admin":
|
||||
test_user = user('org-project-admin', is_superuser=False)
|
||||
organization.project_admin_role.members.add(test_user)
|
||||
|
||||
response = get(url, test_user, expect=200)
|
||||
|
||||
assert 'project_base_dir' in response.data
|
||||
assert 'project_local_paths' in response.data
|
||||
assert 'custom_virtualenvs' in response.data
|
||||
|
||||
def test_job_template_admin_gets_venvs_only(self, get, user, organization, project, inventory):
|
||||
"""Test that JobTemplate admin without org access gets only custom_virtualenvs"""
|
||||
jt_admin = user('jt-admin', is_superuser=False)
|
||||
|
||||
jt = JobTemplate.objects.create(name='test-jt', organization=organization, project=project, inventory=inventory)
|
||||
jt.admin_role.members.add(jt_admin)
|
||||
|
||||
url = reverse('api:api_v2_config_view')
|
||||
response = get(url, jt_admin, expect=200)
|
||||
|
||||
assert 'custom_virtualenvs' in response.data
|
||||
assert 'project_base_dir' not in response.data
|
||||
assert 'project_local_paths' not in response.data
|
||||
|
||||
def test_normal_user_no_conditional_fields(self, get, rando):
|
||||
url = reverse('api:api_v2_config_view')
|
||||
response = get(url, rando, expect=200)
|
||||
|
||||
assert 'project_base_dir' not in response.data
|
||||
assert 'project_local_paths' not in response.data
|
||||
assert 'custom_virtualenvs' not in response.data
|
||||
|
||||
def test_unauthenticated_denied(self, get):
|
||||
"""Test that unauthenticated requests are denied"""
|
||||
url = reverse('api:api_v2_config_view')
|
||||
response = get(url, None, expect=401)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -1,7 +1,5 @@
|
||||
import pytest
|
||||
|
||||
from ansible_base.lib.testing.util import feature_flag_enabled, feature_flag_disabled
|
||||
|
||||
from awx.main.models import CredentialInputSource
|
||||
from awx.api.versioning import reverse
|
||||
|
||||
@@ -318,60 +316,3 @@ def test_create_credential_input_source_with_already_used_input_returns_400(post
|
||||
]
|
||||
all_responses = [post(list_url, params, admin) for params in all_params]
|
||||
assert all_responses.pop().status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_credential_input_source_passes_workload_identity_token_when_flag_enabled(vault_credential, external_credential, mocker):
|
||||
"""Test that workload_identity_token is passed to backend when flag is enabled."""
|
||||
with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
|
||||
# Add workload_identity_token as an internal field on the external credential type
|
||||
# so get_input_value resolves it from the per-input-source context
|
||||
external_credential.credential_type.inputs['fields'].append(
|
||||
{'id': 'workload_identity_token', 'label': 'Workload Identity Token', 'type': 'string', 'internal': True}
|
||||
)
|
||||
|
||||
# Create an input source
|
||||
input_source = CredentialInputSource.objects.create(
|
||||
target_credential=vault_credential,
|
||||
source_credential=external_credential,
|
||||
input_field_name='vault_password',
|
||||
metadata={'key': 'test_key'},
|
||||
)
|
||||
|
||||
# Mock the credential plugin backend
|
||||
mock_backend = mocker.patch.object(external_credential.credential_type.plugin, 'backend', autospec=True, return_value='test_value')
|
||||
|
||||
# Call with context keyed by input source PK
|
||||
test_context = {input_source.pk: {'workload_identity_token': 'jwt_token_here'}}
|
||||
result = input_source.get_input_value(context=test_context)
|
||||
|
||||
# Verify backend was called with workload_identity_token
|
||||
assert result == 'test_value'
|
||||
call_kwargs = mock_backend.call_args[1]
|
||||
assert call_kwargs['workload_identity_token'] == 'jwt_token_here'
|
||||
assert call_kwargs['key'] == 'test_key'
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_credential_input_source_skips_workload_identity_token_when_flag_disabled(vault_credential, external_credential, mocker):
|
||||
"""Test that workload_identity_token is NOT passed when flag is disabled."""
|
||||
with feature_flag_disabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
|
||||
# Create an input source
|
||||
input_source = CredentialInputSource.objects.create(
|
||||
target_credential=vault_credential,
|
||||
source_credential=external_credential,
|
||||
input_field_name='vault_password',
|
||||
metadata={'key': 'test_key'},
|
||||
)
|
||||
# Mock the credential plugin backend
|
||||
mock_backend = mocker.patch.object(external_credential.credential_type.plugin, 'backend', autospec=True, return_value='test_value')
|
||||
# Call with context containing workload_identity_token but NO internal field defined,
|
||||
# simulating a flag-disabled scenario where tokens are not generated upstream
|
||||
test_context = {input_source.pk: {'workload_identity_token': 'jwt_token_here'}}
|
||||
result = input_source.get_input_value(context=test_context)
|
||||
# Verify backend was called WITHOUT workload_identity_token since the credential type
|
||||
# does not define it as an internal field (flag-disabled path doesn't register it)
|
||||
assert result == 'test_value'
|
||||
call_kwargs = mock_backend.call_args[1]
|
||||
assert 'workload_identity_token' not in call_kwargs
|
||||
assert call_kwargs['key'] == 'test_key'
|
||||
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
|
||||
import pytest
|
||||
|
||||
from ansible_base.lib.testing.util import feature_flag_enabled
|
||||
from awx.main.models.credential import CredentialType, Credential
|
||||
from awx.api.versioning import reverse
|
||||
|
||||
@@ -160,8 +159,7 @@ def test_create_as_admin(get, post, admin):
|
||||
response = get(reverse('api:credential_type_list'), admin)
|
||||
assert response.data['count'] == 1
|
||||
assert response.data['results'][0]['name'] == 'Custom Credential Type'
|
||||
# Serializer normalizes empty inputs to {'fields': []}
|
||||
assert response.data['results'][0]['inputs'] == {'fields': []}
|
||||
assert response.data['results'][0]['inputs'] == {}
|
||||
assert response.data['results'][0]['injectors'] == {}
|
||||
assert response.data['results'][0]['managed'] is False
|
||||
|
||||
@@ -476,98 +474,3 @@ def test_credential_type_rbac_external_test(post, alice, admin, credentialtype_e
|
||||
data = {'inputs': {}, 'metadata': {}}
|
||||
assert post(url, data, admin).status_code == 202
|
||||
assert post(url, data, alice).status_code == 403
|
||||
|
||||
|
||||
# --- Tests for internal field filtering with None/invalid inputs ---
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_credential_type_with_none_inputs(get, admin):
|
||||
"""Test that credential type with empty inputs dict works correctly."""
|
||||
# Create a credential type with empty dict
|
||||
ct = CredentialType.objects.create(
|
||||
kind='cloud',
|
||||
name='Test Type',
|
||||
managed=False,
|
||||
inputs={}, # Empty dict, not None (DB has NOT NULL constraint)
|
||||
)
|
||||
|
||||
url = reverse('api:credential_type_detail', kwargs={'pk': ct.pk})
|
||||
response = get(url, admin)
|
||||
assert response.status_code == 200
|
||||
# Should have normalized inputs to empty dict
|
||||
assert 'inputs' in response.data
|
||||
assert isinstance(response.data['inputs'], dict)
|
||||
assert response.data['inputs']['fields'] == []
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_credential_type_with_invalid_inputs_type(get, admin):
|
||||
"""Test that credential type with non-dict inputs doesn't cause errors."""
|
||||
# Create a credential type with invalid inputs type
|
||||
ct = CredentialType.objects.create(kind='cloud', name='Test Type', managed=False, inputs={'fields': 'not-a-list'})
|
||||
|
||||
url = reverse('api:credential_type_detail', kwargs={'pk': ct.pk})
|
||||
response = get(url, admin)
|
||||
assert response.status_code == 200
|
||||
# Should gracefully handle invalid fields type
|
||||
assert 'inputs' in response.data
|
||||
assert response.data['inputs']['fields'] == []
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_credential_type_filters_internal_fields(get, admin):
|
||||
"""Test that internal fields are filtered from API responses."""
|
||||
ct = CredentialType.objects.create(
|
||||
kind='cloud',
|
||||
name='Test OIDC Type',
|
||||
managed=False,
|
||||
inputs={
|
||||
'fields': [
|
||||
{'id': 'url', 'label': 'URL', 'type': 'string'},
|
||||
{'id': 'token', 'label': 'Token', 'type': 'string', 'secret': True, 'internal': True},
|
||||
{'id': 'public_field', 'label': 'Public', 'type': 'string'},
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
url = reverse('api:credential_type_detail', kwargs={'pk': ct.pk})
|
||||
with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
|
||||
response = get(url, admin)
|
||||
assert response.status_code == 200
|
||||
|
||||
field_ids = [f['id'] for f in response.data['inputs']['fields']]
|
||||
# Internal field should be filtered out
|
||||
assert 'token' not in field_ids
|
||||
assert 'url' in field_ids
|
||||
assert 'public_field' in field_ids
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_credential_type_list_filters_internal_fields(get, admin):
|
||||
"""Test that internal fields are filtered in list view."""
|
||||
CredentialType.objects.create(
|
||||
kind='cloud',
|
||||
name='Test OIDC Type',
|
||||
managed=False,
|
||||
inputs={
|
||||
'fields': [
|
||||
{'id': 'url', 'label': 'URL', 'type': 'string'},
|
||||
{'id': 'workload_identity_token', 'label': 'Token', 'type': 'string', 'secret': True, 'internal': True},
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
url = reverse('api:credential_type_list')
|
||||
with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
|
||||
response = get(url, admin)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Find our credential type in the results
|
||||
test_ct = next((ct for ct in response.data['results'] if ct['name'] == 'Test OIDC Type'), None)
|
||||
assert test_ct is not None
|
||||
|
||||
field_ids = [f['id'] for f in test_ct['inputs']['fields']]
|
||||
# Internal field should be filtered out
|
||||
assert 'workload_identity_token' not in field_ids
|
||||
assert 'url' in field_ids
|
||||
|
||||
@@ -485,3 +485,47 @@ class TestJobTemplateCallbackProxyIntegration:
|
||||
expect=400,
|
||||
**headers
|
||||
)
|
||||
|
||||
@override_settings(REMOTE_HOST_HEADERS=['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST'], PROXY_IP_ALLOWED_LIST=[])
|
||||
def test_only_first_entry_in_comma_separated_header_is_considered(self, job_template, admin_user, post):
|
||||
"""
|
||||
Test that only the first entry in a comma-separated header value is used for host matching.
|
||||
This is important for X-Forwarded-For style headers where the format is "client, proxy1, proxy2".
|
||||
Only the original client (first entry) should be matched against inventory hosts.
|
||||
"""
|
||||
# Create host that matches the SECOND entry in the comma-separated list
|
||||
job_template.inventory.hosts.create(name='second-host.example.com')
|
||||
|
||||
headers = {
|
||||
# First entry is 'first-host.example.com', second is 'second-host.example.com'
|
||||
# Only the first should be considered, so this should NOT match
|
||||
'HTTP_X_FROM_THE_LOAD_BALANCER': 'first-host.example.com, second-host.example.com',
|
||||
'REMOTE_ADDR': 'unrelated-addr',
|
||||
'REMOTE_HOST': 'unrelated-host',
|
||||
}
|
||||
|
||||
# Should return 400 because only 'first-host.example.com' is considered,
|
||||
# and that host is NOT in the inventory
|
||||
r = post(
|
||||
url=reverse('api:job_template_callback', kwargs={'pk': job_template.pk}), data={'host_config_key': 'abcd'}, user=admin_user, expect=400, **headers
|
||||
)
|
||||
assert r.data['msg'] == 'No matching host could be found!'
|
||||
|
||||
@override_settings(REMOTE_HOST_HEADERS=['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST'], PROXY_IP_ALLOWED_LIST=[])
|
||||
def test_first_entry_in_comma_separated_header_matches(self, job_template, admin_user, post):
|
||||
"""
|
||||
Test that the first entry in a comma-separated header value correctly matches an inventory host.
|
||||
"""
|
||||
# Create host that matches the FIRST entry in the comma-separated list
|
||||
job_template.inventory.hosts.create(name='first-host.example.com')
|
||||
|
||||
headers = {
|
||||
# First entry is 'first-host.example.com', second is 'second-host.example.com'
|
||||
# The first entry matches the inventory host
|
||||
'HTTP_X_FROM_THE_LOAD_BALANCER': 'first-host.example.com, second-host.example.com',
|
||||
'REMOTE_ADDR': 'unrelated-addr',
|
||||
'REMOTE_HOST': 'unrelated-host',
|
||||
}
|
||||
|
||||
# Should return 201 because 'first-host.example.com' is the first entry and matches
|
||||
post(url=reverse('api:job_template_callback', kwargs={'pk': job_template.pk}), data={'host_config_key': 'abcd'}, user=admin_user, expect=201, **headers)
|
||||
|
||||
@@ -1,163 +0,0 @@
|
||||
"""
|
||||
Tests for OIDC workload identity credential type feature flag.
|
||||
|
||||
The FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED flag is an install-time flag that
|
||||
controls whether OIDC credential types are loaded into the registry at startup.
|
||||
When disabled, OIDC credential types are not loaded and do not exist in the database.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
from django.test import override_settings
|
||||
|
||||
from awx.main.constants import OIDC_CREDENTIAL_TYPE_NAMESPACES
|
||||
from awx.main.models.credential import CredentialType, ManagedCredentialType, load_credentials
|
||||
from awx.api.versioning import reverse
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reload_credentials_with_flag(django_db_setup, django_db_blocker):
|
||||
"""
|
||||
Fixture that reloads credentials with a specific flag state.
|
||||
This simulates what happens at application startup.
|
||||
"""
|
||||
# Save original registry state
|
||||
original_registry = ManagedCredentialType.registry.copy()
|
||||
|
||||
def _reload(flag_enabled):
|
||||
with django_db_blocker.unblock():
|
||||
# Clear the entire registry before reloading
|
||||
ManagedCredentialType.registry.clear()
|
||||
|
||||
# Reload credentials with the specified flag state
|
||||
with override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=flag_enabled):
|
||||
with mock.patch('awx.main.models.credential.detect_server_product_name', return_value='NOT_AWX'):
|
||||
load_credentials()
|
||||
|
||||
# Sync to database
|
||||
CredentialType.setup_tower_managed_defaults(lock=False)
|
||||
|
||||
# In tests, the session fixture pre-loads all credential types into the DB.
|
||||
# Remove OIDC types when testing the disabled state so the API test is accurate.
|
||||
if not flag_enabled:
|
||||
CredentialType.objects.filter(namespace__in=OIDC_CREDENTIAL_TYPE_NAMESPACES).delete()
|
||||
|
||||
yield _reload
|
||||
|
||||
# Restore original registry state after tests
|
||||
ManagedCredentialType.registry.clear()
|
||||
ManagedCredentialType.registry.update(original_registry)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def isolated_registry():
|
||||
"""Save and restore the ManagedCredentialType registry, with full isolation via mocked entry_points."""
|
||||
original_registry = ManagedCredentialType.registry.copy()
|
||||
ManagedCredentialType.registry.clear()
|
||||
yield
|
||||
ManagedCredentialType.registry.clear()
|
||||
ManagedCredentialType.registry.update(original_registry)
|
||||
|
||||
|
||||
def _make_mock_entry_point(name):
|
||||
"""Create a mock entry point that mimics a credential plugin."""
|
||||
ep = mock.MagicMock()
|
||||
ep.name = name
|
||||
ep.value = f'test_plugin:{name}'
|
||||
plugin = mock.MagicMock(spec=[])
|
||||
ep.load.return_value = plugin
|
||||
return ep
|
||||
|
||||
|
||||
def _mock_entry_points_factory(managed_names, supported_names):
|
||||
"""Return a side_effect function for mocking entry_points() with controlled plugins."""
|
||||
managed = [_make_mock_entry_point(n) for n in managed_names]
|
||||
supported = [_make_mock_entry_point(n) for n in supported_names]
|
||||
|
||||
def _entry_points(group):
|
||||
if group == 'awx_plugins.managed_credentials':
|
||||
return managed
|
||||
elif group == 'awx_plugins.managed_credentials.supported':
|
||||
return supported
|
||||
return []
|
||||
|
||||
return _entry_points
|
||||
|
||||
|
||||
# --- Unit tests for load_credentials() registry behavior ---
|
||||
|
||||
|
||||
def test_oidc_types_in_registry_when_flag_enabled(isolated_registry):
|
||||
"""Test that OIDC credential types are added to the registry when flag is enabled."""
|
||||
mock_eps = _mock_entry_points_factory(
|
||||
managed_names=['ssh', 'vault'],
|
||||
supported_names=['hashivault-kv-oidc', 'hashivault-ssh-oidc'],
|
||||
)
|
||||
with override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=True):
|
||||
with mock.patch('awx.main.models.credential.detect_server_product_name', return_value='NOT_AWX'):
|
||||
with mock.patch('awx.main.models.credential.entry_points', side_effect=mock_eps):
|
||||
load_credentials()
|
||||
|
||||
for ns in OIDC_CREDENTIAL_TYPE_NAMESPACES:
|
||||
assert ns in ManagedCredentialType.registry, f"{ns} should be in registry when flag is enabled"
|
||||
assert 'ssh' in ManagedCredentialType.registry
|
||||
assert 'vault' in ManagedCredentialType.registry
|
||||
|
||||
|
||||
def test_oidc_types_not_in_registry_when_flag_disabled(isolated_registry):
|
||||
"""Test that OIDC credential types are excluded from the registry when flag is disabled."""
|
||||
mock_eps = _mock_entry_points_factory(
|
||||
managed_names=['ssh', 'vault'],
|
||||
supported_names=['hashivault-kv-oidc', 'hashivault-ssh-oidc'],
|
||||
)
|
||||
with override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=False):
|
||||
with mock.patch('awx.main.models.credential.detect_server_product_name', return_value='NOT_AWX'):
|
||||
with mock.patch('awx.main.models.credential.entry_points', side_effect=mock_eps):
|
||||
load_credentials()
|
||||
|
||||
for ns in OIDC_CREDENTIAL_TYPE_NAMESPACES:
|
||||
assert ns not in ManagedCredentialType.registry, f"{ns} should not be in registry when flag is disabled"
|
||||
# Non-OIDC types should still be loaded
|
||||
assert 'ssh' in ManagedCredentialType.registry
|
||||
assert 'vault' in ManagedCredentialType.registry
|
||||
|
||||
|
||||
def test_oidc_namespaces_constant():
|
||||
"""Test that OIDC_CREDENTIAL_TYPE_NAMESPACES contains the expected namespaces."""
|
||||
assert 'hashivault-kv-oidc' in OIDC_CREDENTIAL_TYPE_NAMESPACES
|
||||
assert 'hashivault-ssh-oidc' in OIDC_CREDENTIAL_TYPE_NAMESPACES
|
||||
assert len(OIDC_CREDENTIAL_TYPE_NAMESPACES) == 2
|
||||
|
||||
|
||||
# --- Functional API tests ---
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_oidc_types_loaded_when_flag_enabled(get, admin, reload_credentials_with_flag):
|
||||
"""Test that OIDC credential types are visible in the API when flag is enabled."""
|
||||
reload_credentials_with_flag(flag_enabled=True)
|
||||
|
||||
response = get(reverse('api:credential_type_list'), admin)
|
||||
assert response.status_code == 200
|
||||
|
||||
namespaces = [ct['namespace'] for ct in response.data['results']]
|
||||
assert 'hashivault-kv-oidc' in namespaces
|
||||
assert 'hashivault-ssh-oidc' in namespaces
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_oidc_types_not_loaded_when_flag_disabled(get, admin, reload_credentials_with_flag):
|
||||
"""Test that OIDC credential types are not visible in the API when flag is disabled."""
|
||||
reload_credentials_with_flag(flag_enabled=False)
|
||||
|
||||
response = get(reverse('api:credential_type_list'), admin)
|
||||
assert response.status_code == 200
|
||||
|
||||
namespaces = [ct['namespace'] for ct in response.data['results']]
|
||||
assert 'hashivault-kv-oidc' not in namespaces
|
||||
assert 'hashivault-ssh-oidc' not in namespaces
|
||||
|
||||
# Verify they're also not in the database
|
||||
assert not CredentialType.objects.filter(namespace='hashivault-kv-oidc').exists()
|
||||
assert not CredentialType.objects.filter(namespace='hashivault-ssh-oidc').exists()
|
||||
@@ -1,312 +0,0 @@
|
||||
"""
|
||||
Tests for OIDC workload identity credential test endpoints.
|
||||
|
||||
Tests the /api/v2/credentials/<id>/test/ and /api/v2/credential_types/<id>/test/
|
||||
endpoints when used with OIDC-enabled credential types.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
from django.test import override_settings
|
||||
|
||||
from awx.main.models import Credential, CredentialType, JobTemplate
|
||||
from awx.api.versioning import reverse
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def job_template(organization, project):
|
||||
"""Job template with organization and project for OIDC JWT generation."""
|
||||
return JobTemplate.objects.create(name='test-jt', organization=organization, project=project, playbook='helloworld.yml')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oidc_credentialtype():
|
||||
"""Create a credential type with workload_identity_token internal field."""
|
||||
oidc_type_inputs = {
|
||||
'fields': [
|
||||
{'id': 'url', 'label': 'Vault URL', 'type': 'string', 'help_text': 'The Vault server URL.'},
|
||||
{'id': 'auth_path', 'label': 'Auth Path', 'type': 'string', 'help_text': 'JWT auth mount path.'},
|
||||
{'id': 'role_id', 'label': 'Role ID', 'type': 'string', 'help_text': 'Vault role.'},
|
||||
{'id': 'jwt_aud', 'label': 'JWT Audience', 'type': 'string', 'help_text': 'Expected audience.'},
|
||||
{'id': 'workload_identity_token', 'label': 'Workload Identity Token', 'type': 'string', 'secret': True, 'internal': True},
|
||||
],
|
||||
'metadata': [
|
||||
{'id': 'secret_path', 'label': 'Secret Path', 'type': 'string'},
|
||||
{'id': 'job_template_id', 'label': 'Job Template ID', 'type': 'string'},
|
||||
],
|
||||
'required': ['url', 'auth_path', 'role_id'],
|
||||
}
|
||||
|
||||
class MockPlugin(object):
|
||||
def backend(self, **kwargs):
|
||||
# Simulate successful backend call
|
||||
return 'secret'
|
||||
|
||||
with mock.patch('awx.main.models.credential.CredentialType.plugin', new_callable=mock.PropertyMock) as mock_plugin:
|
||||
mock_plugin.return_value = MockPlugin()
|
||||
oidc_type = CredentialType(kind='external', managed=True, namespace='hashivault-kv-oidc', name='HashiCorp Vault KV (OIDC)', inputs=oidc_type_inputs)
|
||||
oidc_type.save()
|
||||
yield oidc_type
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oidc_credential(oidc_credentialtype):
|
||||
"""Create a credential using the OIDC credential type."""
|
||||
return Credential.objects.create(
|
||||
credential_type=oidc_credentialtype,
|
||||
name='oidc-vault-cred',
|
||||
inputs={'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role', 'jwt_aud': 'vault'},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oidc_backend():
|
||||
"""Fixture that mocks OIDC JWT generation and credential backend."""
|
||||
with mock.patch('awx.api.views.retrieve_workload_identity_jwt_with_claims') as mock_jwt, mock.patch('awx.api.views._jwt_decode') as mock_decode, mock.patch(
|
||||
'awx.main.models.credential.CredentialType.plugin', new_callable=mock.PropertyMock
|
||||
) as mock_plugin:
|
||||
|
||||
# Set default return values
|
||||
mock_jwt.return_value = 'fake.jwt.token'
|
||||
mock_decode.return_value = {'iss': 'http://gateway/o', 'aud': 'vault'}
|
||||
|
||||
# Create mock backend
|
||||
mock_backend = mock.MagicMock()
|
||||
mock_backend.backend.return_value = 'secret'
|
||||
mock_plugin.return_value = mock_backend
|
||||
|
||||
# Yield all mocks for test customization
|
||||
yield {
|
||||
'jwt': mock_jwt,
|
||||
'decode': mock_decode,
|
||||
'plugin': mock_plugin,
|
||||
'backend': mock_backend,
|
||||
}
|
||||
|
||||
|
||||
# --- Tests for CredentialExternalTest endpoint ---
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=False)
|
||||
def test_credential_test_without_oidc_feature_flag(post, admin, oidc_credential):
|
||||
"""Test that credential test works without OIDC feature flag enabled."""
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
|
||||
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': '1'}}
|
||||
|
||||
with mock.patch('awx.main.models.credential.CredentialType.plugin', new_callable=mock.PropertyMock) as mock_plugin:
|
||||
mock_backend = mock.MagicMock()
|
||||
mock_backend.backend.return_value = 'secret'
|
||||
mock_plugin.return_value = mock_backend
|
||||
|
||||
response = post(url, data, admin)
|
||||
assert response.status_code == 202
|
||||
# Should not contain JWT payload when feature flag is disabled
|
||||
assert 'details' not in response.data or 'sent_jwt_payload' not in response.data.get('details', {})
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
@pytest.mark.parametrize(
|
||||
'job_template_id, expected_error',
|
||||
[
|
||||
(None, 'Job template ID is required'),
|
||||
('not-an-integer', 'must be an integer'),
|
||||
('99999', 'does not exist'),
|
||||
],
|
||||
ids=['missing_job_template_id', 'invalid_job_template_id_type', 'nonexistent_job_template_id'],
|
||||
)
|
||||
def test_credential_test_job_template_validation(mock_flag, post, admin, oidc_credential, job_template_id, expected_error):
|
||||
"""Test that invalid job_template_id values return 400 with appropriate error messages."""
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
|
||||
data = {'metadata': {'secret_path': 'test/secret'}}
|
||||
if job_template_id is not None:
|
||||
data['metadata']['job_template_id'] = job_template_id
|
||||
|
||||
response = post(url, data, admin)
|
||||
assert response.status_code == 400
|
||||
assert 'details' in response.data
|
||||
assert 'error_message' in response.data['details']
|
||||
assert expected_error in response.data['details']['error_message']
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_test_no_access_to_job_template(mock_flag, post, alice, oidc_credential, job_template):
|
||||
"""Test that user without access to job template gets 403."""
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
|
||||
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
|
||||
|
||||
# Give alice use permission on credential but not on job template
|
||||
oidc_credential.use_role.members.add(alice)
|
||||
|
||||
response = post(url, data, alice)
|
||||
assert response.status_code == 403
|
||||
assert 'You do not have access to job template' in str(response.data)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_test_success_returns_jwt_payload(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend):
|
||||
"""Test that successful test returns JWT payload in response."""
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
|
||||
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
|
||||
|
||||
# Customize mock for this test
|
||||
mock_oidc_backend['decode'].return_value = {
|
||||
'iss': 'http://gateway/o',
|
||||
'sub': 'system:serviceaccount:default:awx-operator',
|
||||
'aud': 'vault',
|
||||
'job_template_id': job_template.id,
|
||||
}
|
||||
|
||||
response = post(url, data, admin)
|
||||
assert response.status_code == 202
|
||||
assert 'details' in response.data
|
||||
assert 'sent_jwt_payload' in response.data['details']
|
||||
assert response.data['details']['sent_jwt_payload']['job_template_id'] == job_template.id
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_test_response_does_not_contain_secret_value(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend):
|
||||
"""
|
||||
the OIDC credential test endpoint must not echo the resolved Vault secret back to the caller.
|
||||
"""
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
|
||||
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
|
||||
|
||||
credential_secret_value = 'CREDENTIAL_SECRET'
|
||||
mock_oidc_backend['backend'].backend.return_value = credential_secret_value
|
||||
|
||||
response = post(url, data, admin)
|
||||
|
||||
assert response.status_code == 202
|
||||
assert 'details' in response.data
|
||||
assert 'sent_jwt_payload' in response.data['details']
|
||||
assert 'secret_value' not in response.data['details']
|
||||
assert credential_secret_value not in str(response.data)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_test_backend_failure_returns_jwt_and_error(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend):
|
||||
"""Test that backend failure still returns JWT payload along with error message."""
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
|
||||
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
|
||||
|
||||
# Make backend fail
|
||||
mock_oidc_backend['backend'].backend.side_effect = RuntimeError('Connection failed')
|
||||
|
||||
response = post(url, data, admin)
|
||||
assert response.status_code == 400
|
||||
assert 'details' in response.data
|
||||
# Both JWT payload and error message should be present
|
||||
assert 'sent_jwt_payload' in response.data['details']
|
||||
assert 'error_message' in response.data['details']
|
||||
assert 'Connection failed' in response.data['details']['error_message']
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_test_jwt_generation_failure(mock_flag, post, admin, oidc_credential, job_template):
|
||||
"""Test that JWT generation failure returns error without JWT payload."""
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
|
||||
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
|
||||
|
||||
with mock.patch('awx.api.views.OIDCCredentialTestMixin._get_workload_identity_token') as mock_jwt:
|
||||
mock_jwt.side_effect = RuntimeError('Failed to generate JWT')
|
||||
|
||||
response = post(url, data, admin)
|
||||
assert response.status_code == 400
|
||||
assert 'details' in response.data
|
||||
assert 'error_message' in response.data['details']
|
||||
assert 'Failed to generate JWT' in response.data['details']['error_message']
|
||||
# No JWT payload when generation fails
|
||||
assert 'sent_jwt_payload' not in response.data['details']
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_test_job_template_id_not_passed_to_backend(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend):
|
||||
"""Test that job_template_id and jwt_aud are removed from backend_kwargs."""
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
|
||||
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
|
||||
|
||||
response = post(url, data, admin)
|
||||
assert response.status_code == 202
|
||||
|
||||
# Check that backend was called without job_template_id or jwt_aud
|
||||
call_kwargs = mock_oidc_backend['backend'].backend.call_args[1]
|
||||
assert 'job_template_id' not in call_kwargs
|
||||
assert 'jwt_aud' not in call_kwargs
|
||||
assert 'workload_identity_token' in call_kwargs
|
||||
|
||||
|
||||
# --- Tests for CredentialTypeExternalTest endpoint ---
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_type_test_response_does_not_contain_secret_value(mock_flag, post, admin, oidc_credentialtype, job_template, mock_oidc_backend):
|
||||
"""
|
||||
the credential-type variant of the test endpoint should not return the secret value
|
||||
"""
|
||||
url = reverse('api:credential_type_external_test', kwargs={'pk': oidc_credentialtype.pk})
|
||||
data = {
|
||||
'inputs': {'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role', 'jwt_aud': 'vault'},
|
||||
'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)},
|
||||
}
|
||||
|
||||
credential_type_seret_value = 'CREDENTIAL_TYPE_SECRET'
|
||||
mock_oidc_backend['backend'].backend.return_value = credential_type_seret_value
|
||||
response = post(url, data, admin)
|
||||
|
||||
assert response.status_code == 202
|
||||
assert 'details' in response.data
|
||||
assert 'sent_jwt_payload' in response.data['details']
|
||||
assert 'secret_value' not in response.data['details']
|
||||
assert credential_type_seret_value not in str(response.data)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_type_test_missing_job_template_id(mock_flag, post, admin, oidc_credentialtype):
|
||||
"""Test that missing job_template_id returns 400 for credential type test endpoint."""
|
||||
url = reverse('api:credential_type_external_test', kwargs={'pk': oidc_credentialtype.pk})
|
||||
data = {
|
||||
'inputs': {'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role', 'jwt_aud': 'vault'},
|
||||
'metadata': {'secret_path': 'test/secret'},
|
||||
}
|
||||
|
||||
response = post(url, data, admin)
|
||||
assert response.status_code == 400
|
||||
assert 'details' in response.data
|
||||
assert 'error_message' in response.data['details']
|
||||
assert 'Job template ID is required' in response.data['details']['error_message']
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_type_test_success_returns_jwt_payload(mock_flag, post, admin, oidc_credentialtype, job_template, mock_oidc_backend):
|
||||
"""Test that successful credential type test returns JWT payload."""
|
||||
url = reverse('api:credential_type_external_test', kwargs={'pk': oidc_credentialtype.pk})
|
||||
data = {
|
||||
'inputs': {'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role', 'jwt_aud': 'vault'},
|
||||
'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)},
|
||||
}
|
||||
|
||||
response = post(url, data, admin)
|
||||
assert response.status_code == 202
|
||||
assert 'details' in response.data
|
||||
assert 'sent_jwt_payload' in response.data['details']
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_credential_external_test_returns_400_for_non_external_credential(post, admin, credential):
|
||||
# credential fixture creates a non-external credential (e.g. SSH/vault kind)
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': credential.pk})
|
||||
response = post(url, {'metadata': {}}, admin)
|
||||
assert response.status_code == 400
|
||||
assert 'not testable' in response.data.get('detail', '').lower()
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import date
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
@@ -252,7 +253,7 @@ def test_user_verify_attribute_created(admin, get):
|
||||
resp = get(reverse('api:user_detail', kwargs={'pk': admin.pk}), admin)
|
||||
assert resp.data['created'] == admin.date_joined
|
||||
|
||||
past = "2020-01-01T00:00:00Z"
|
||||
past = date(2020, 1, 1).isoformat()
|
||||
for op, count in (('gt', 1), ('lt', 0)):
|
||||
resp = get(reverse('api:user_list') + f'?created__{op}={past}', admin)
|
||||
assert resp.data['count'] == count
|
||||
|
||||
@@ -13,7 +13,6 @@ from awx.main.models.workflow import (
|
||||
WorkflowJobTemplateNode,
|
||||
)
|
||||
from awx.main.models.credential import Credential
|
||||
from awx.main.models.label import Label
|
||||
from awx.main.scheduler import TaskManager, WorkflowManager, DependencyManager
|
||||
|
||||
# Django
|
||||
@@ -52,31 +51,6 @@ def test_node_accepts_prompted_fields(inventory, project, workflow_job_template,
|
||||
post(url, {'unified_job_template': job_template.pk, 'limit': 'webservers'}, user=admin_user, expect=201)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_node_extra_data_patch_with_unprompted_labels(inventory, project, organization, workflow_job_template, patch, admin_user):
|
||||
"""AAP-41742: PATCH extra_data on a workflow node should succeed even when
|
||||
the node has labels associated but the JT has ask_labels_on_launch=False."""
|
||||
jt = JobTemplate.objects.create(
|
||||
inventory=inventory,
|
||||
project=project,
|
||||
playbook='helloworld.yml',
|
||||
ask_variables_on_launch=True,
|
||||
ask_labels_on_launch=False,
|
||||
)
|
||||
label = Label.objects.create(name='repro-label', organization=organization)
|
||||
|
||||
node = WorkflowJobTemplateNode.objects.create(
|
||||
workflow_job_template=workflow_job_template,
|
||||
unified_job_template=jt,
|
||||
extra_data={'foo': 'bar'},
|
||||
)
|
||||
node.labels.add(label)
|
||||
|
||||
url = reverse('api:workflow_job_template_node_detail', kwargs={'pk': node.pk})
|
||||
r = patch(url, {'extra_data': {'foo': 'edited'}}, user=admin_user, expect=200)
|
||||
assert r.data['extra_data'] == {'foo': 'edited'}
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.parametrize(
|
||||
"field_name, field_value",
|
||||
|
||||
@@ -48,7 +48,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
|
||||
worker = CallbackBrokerWorker()
|
||||
events = [InventoryUpdateEvent(uuid=str(uuid4()), **self.event_create_kwargs())]
|
||||
worker.buff = {InventoryUpdateEvent: events}
|
||||
worker.flush(force=True)
|
||||
worker.flush()
|
||||
assert worker.buff.get(InventoryUpdateEvent, []) == []
|
||||
assert InventoryUpdateEvent.objects.filter(uuid=events[0].uuid).count() == 1
|
||||
|
||||
@@ -61,7 +61,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
|
||||
InventoryUpdateEvent(uuid=str(uuid4()), stdout='good2', **kwargs),
|
||||
]
|
||||
worker.buff = {InventoryUpdateEvent: events.copy()}
|
||||
worker.flush(force=True)
|
||||
worker.flush()
|
||||
assert InventoryUpdateEvent.objects.filter(uuid=events[0].uuid).count() == 1
|
||||
assert InventoryUpdateEvent.objects.filter(uuid=events[1].uuid).count() == 0
|
||||
assert InventoryUpdateEvent.objects.filter(uuid=events[2].uuid).count() == 1
|
||||
@@ -71,7 +71,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
|
||||
worker = CallbackBrokerWorker()
|
||||
events = [InventoryUpdateEvent(uuid=str(uuid4()), **self.event_create_kwargs())]
|
||||
worker.buff = {InventoryUpdateEvent: events.copy()}
|
||||
worker.flush(force=True)
|
||||
worker.flush()
|
||||
|
||||
# put current saved event in buffer (error case)
|
||||
worker.buff = {InventoryUpdateEvent: [InventoryUpdateEvent.objects.get(uuid=events[0].uuid)]}
|
||||
@@ -113,7 +113,7 @@ class TestCallbackBrokerWorker(TransactionTestCase):
|
||||
|
||||
with mock.patch.object(InventoryUpdateEvent.objects, 'bulk_create', side_effect=ValueError):
|
||||
with mock.patch.object(events[0], 'save', side_effect=ValueError):
|
||||
worker.flush(force=True)
|
||||
worker.flush()
|
||||
|
||||
assert "\x00" not in events[0].stdout
|
||||
|
||||
|
||||
@@ -305,47 +305,6 @@ class TestINIImports:
|
||||
has_host_group = inventory.groups.get(name='has_a_host')
|
||||
assert has_host_group.hosts.count() == 1
|
||||
|
||||
@mock.patch.object(inventory_import, 'AnsibleInventoryLoader', MockLoader)
|
||||
def test_overwrite_removes_stale_memberships(self, inventory):
|
||||
"""When overwrite is enabled, host-group and group-group memberships
|
||||
that are no longer in the imported data should be removed."""
|
||||
# First import: parent_group has two children, host_group has two hosts
|
||||
inventory_import.AnsibleInventoryLoader._data = {
|
||||
"_meta": {"hostvars": {"host1": {}, "host2": {}}},
|
||||
"all": {"children": ["ungrouped", "parent_group", "child_a", "child_b", "host_group"]},
|
||||
"parent_group": {"children": ["child_a", "child_b"]},
|
||||
"host_group": {"hosts": ["host1", "host2"]},
|
||||
"ungrouped": {"hosts": []},
|
||||
}
|
||||
cmd = inventory_import.Command()
|
||||
cmd.handle(inventory_id=inventory.pk, source=__file__, overwrite=True)
|
||||
|
||||
parent = inventory.groups.get(name='parent_group')
|
||||
assert set(parent.children.values_list('name', flat=True)) == {'child_a', 'child_b'}
|
||||
host_grp = inventory.groups.get(name='host_group')
|
||||
assert set(host_grp.hosts.values_list('name', flat=True)) == {'host1', 'host2'}
|
||||
|
||||
# Second import: child_b removed from parent_group, host2 moved out of host_group
|
||||
inventory_import.AnsibleInventoryLoader._data = {
|
||||
"_meta": {"hostvars": {"host1": {}, "host2": {}}},
|
||||
"all": {"children": ["ungrouped", "parent_group", "child_a", "child_b", "host_group"]},
|
||||
"parent_group": {"children": ["child_a"]},
|
||||
"host_group": {"hosts": ["host1"]},
|
||||
"ungrouped": {"hosts": ["host2"]},
|
||||
}
|
||||
cmd = inventory_import.Command()
|
||||
cmd.handle(inventory_id=inventory.pk, source=__file__, overwrite=True)
|
||||
|
||||
parent.refresh_from_db()
|
||||
host_grp.refresh_from_db()
|
||||
# child_b should be removed from parent_group
|
||||
assert set(parent.children.values_list('name', flat=True)) == {'child_a'}
|
||||
# host2 should be removed from host_group
|
||||
assert set(host_grp.hosts.values_list('name', flat=True)) == {'host1'}
|
||||
# host2 and child_b should still exist in the inventory, just not in those groups
|
||||
assert inventory.hosts.filter(name='host2').exists()
|
||||
assert inventory.groups.filter(name='child_b').exists()
|
||||
|
||||
@mock.patch.object(inventory_import, 'AnsibleInventoryLoader', MockLoader)
|
||||
def test_recursive_group_error(self, inventory):
|
||||
inventory_import.AnsibleInventoryLoader._data = {
|
||||
|
||||
@@ -10,23 +10,9 @@ from django.test.utils import override_settings
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_multiple_hybrid_instances():
|
||||
for i in range(3):
|
||||
Instance.objects.create(hostname=f'foo{i}', node_type='hybrid')
|
||||
assert is_ha_environment()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_double_control_instances():
|
||||
def test_multiple_instances():
|
||||
for i in range(2):
|
||||
Instance.objects.create(hostname=f'foo{i}', node_type='control')
|
||||
assert is_ha_environment()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_mix_hybrid_control_instances():
|
||||
Instance.objects.create(hostname='control_node', node_type='control')
|
||||
Instance.objects.create(hostname='hybrid_node', node_type='hybrid')
|
||||
Instance.objects.create(hostname=f'foo{i}', node_type='hybrid')
|
||||
assert is_ha_environment()
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from awx.main.models import JobTemplate, Job, JobHostSummary, WorkflowJob, Inventory, Host, Project, Organization
|
||||
from awx.main.models import JobTemplate, Job, JobHostSummary, WorkflowJob, Inventory, Project, Organization
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@@ -87,47 +87,3 @@ class TestSlicingModels:
|
||||
|
||||
unified_job = job_template.create_unified_job(job_slice_count=2)
|
||||
assert isinstance(unified_job, Job)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestGetSourceHostsForConstructedInventory:
|
||||
"""Tests for Job.get_source_hosts_for_constructed_inventory"""
|
||||
|
||||
def test_returns_source_hosts_via_instance_id(self):
|
||||
"""Constructed hosts with instance_id pointing to source hosts are resolved correctly."""
|
||||
org = Organization.objects.create(name='test-org')
|
||||
inv_input = Inventory.objects.create(organization=org, name='input-inv')
|
||||
source_host1 = inv_input.hosts.create(name='host1')
|
||||
source_host2 = inv_input.hosts.create(name='host2')
|
||||
|
||||
inv_constructed = Inventory.objects.create(organization=org, name='constructed-inv', kind='constructed')
|
||||
inv_constructed.input_inventories.add(inv_input)
|
||||
Host.objects.create(inventory=inv_constructed, name='host1', instance_id=str(source_host1.id))
|
||||
Host.objects.create(inventory=inv_constructed, name='host2', instance_id=str(source_host2.id))
|
||||
|
||||
job = Job.objects.create(name='test-job', inventory=inv_constructed)
|
||||
result = job.get_source_hosts_for_constructed_inventory()
|
||||
|
||||
assert set(result.values_list('id', flat=True)) == {source_host1.id, source_host2.id}
|
||||
|
||||
def test_no_inventory_returns_empty(self):
|
||||
"""A job with no inventory returns an empty queryset."""
|
||||
job = Job.objects.create(name='test-job')
|
||||
result = job.get_source_hosts_for_constructed_inventory()
|
||||
assert result.count() == 0
|
||||
|
||||
def test_ignores_hosts_without_instance_id(self):
|
||||
"""Hosts with empty instance_id are excluded from the result."""
|
||||
org = Organization.objects.create(name='test-org')
|
||||
inv_input = Inventory.objects.create(organization=org, name='input-inv')
|
||||
source_host = inv_input.hosts.create(name='host1')
|
||||
|
||||
inv_constructed = Inventory.objects.create(organization=org, name='constructed-inv', kind='constructed')
|
||||
inv_constructed.input_inventories.add(inv_input)
|
||||
Host.objects.create(inventory=inv_constructed, name='host1', instance_id=str(source_host.id))
|
||||
Host.objects.create(inventory=inv_constructed, name='host-no-ref', instance_id='')
|
||||
|
||||
job = Job.objects.create(name='test-job', inventory=inv_constructed)
|
||||
result = job.get_source_hosts_for_constructed_inventory()
|
||||
|
||||
assert list(result.values_list('id', flat=True)) == [source_host.id]
|
||||
|
||||
@@ -101,34 +101,6 @@ def test_host_access(organization, inventory, group, user, group_factory):
|
||||
assert inventory_admin_access.can_read(host) is False
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_host_access_excludes_constructed_inventory_hosts(organization, inventory, user):
|
||||
"""
|
||||
Exclude hosts from constructed inventory for all users.
|
||||
"""
|
||||
constructed_inv = organization.inventories.create(name='constructed-inv', kind='constructed')
|
||||
real_host = Host.objects.create(inventory=inventory, name='hostA')
|
||||
shadow_host = Host.objects.create(inventory=constructed_inv, name='hostA')
|
||||
|
||||
# Non-superuser with read on both inventories
|
||||
reader = user('reader', False)
|
||||
inventory.read_role.members.add(reader)
|
||||
constructed_inv.read_role.members.add(reader)
|
||||
|
||||
reader_qs = HostAccess(reader).get_queryset()
|
||||
assert real_host in reader_qs
|
||||
assert shadow_host not in reader_qs
|
||||
|
||||
# Superuser path: should get the same result
|
||||
superuser = user('super', True)
|
||||
super_qs = HostAccess(superuser).get_queryset()
|
||||
assert real_host in super_qs
|
||||
assert shadow_host not in super_qs
|
||||
|
||||
# Sanity: shadow rows still exist in the DB and are reachable via inventory filtering
|
||||
assert Host.objects.filter(inventory=constructed_inv, name='hostA').exists()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_inventory_source_credential_check(rando, inventory_source, credential):
|
||||
inventory_source.inventory.admin_role.members.add(rando)
|
||||
|
||||
@@ -1,274 +0,0 @@
|
||||
# Generated by Claude Opus 4.6 (claude-opus-4-6)
|
||||
#
|
||||
# Test file for cancel + dependency chain behavior and workflow cancel propagation.
|
||||
#
|
||||
# These tests verify:
|
||||
#
|
||||
# 1. TaskManager.process_job_dep_failures() correctly distinguishes canceled vs
|
||||
# failed dependencies in the job_explanation message.
|
||||
#
|
||||
# 2. TaskManager.process_pending_tasks() transitions pending jobs with
|
||||
# cancel_flag=True directly to canceled status.
|
||||
#
|
||||
# 3. WorkflowManager + TaskManager together cancel all spawned jobs in a
|
||||
# workflow and finalize the workflow as canceled.
|
||||
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
from awx.main.scheduler import TaskManager, DependencyManager, WorkflowManager
|
||||
from awx.main.models import JobTemplate, ProjectUpdate, WorkflowApproval, WorkflowJobTemplate
|
||||
from awx.main.models.workflow import WorkflowApprovalTemplate
|
||||
from . import create_job
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def scm_on_launch_objects(job_template_factory):
|
||||
"""Create a job template with a project configured for scm_update_on_launch."""
|
||||
objects = job_template_factory(
|
||||
'jt',
|
||||
organization='org1',
|
||||
project='proj',
|
||||
inventory='inv',
|
||||
credential='cred',
|
||||
)
|
||||
p = objects.project
|
||||
p.scm_update_on_launch = True
|
||||
p.scm_update_cache_timeout = 0
|
||||
p.save(skip_update=True)
|
||||
return objects
|
||||
|
||||
|
||||
def _create_job_with_dependency(objects):
|
||||
"""Create a pending job and run DependencyManager to produce its project update dependency.
|
||||
|
||||
Returns (job, project_update).
|
||||
"""
|
||||
j = create_job(objects.job_template, dependencies_processed=False)
|
||||
with mock.patch('awx.main.models.unified_jobs.UnifiedJobTemplate.update'):
|
||||
DependencyManager().schedule()
|
||||
assert j.dependent_jobs.count() == 1
|
||||
pu = j.dependent_jobs.first()
|
||||
assert isinstance(pu.get_real_instance(), ProjectUpdate)
|
||||
return j, pu
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestCanceledDependencyFailsBlockedJob:
|
||||
"""When a dependency project update is canceled or failed, the task manager
|
||||
should fail the blocked job via process_job_dep_failures."""
|
||||
|
||||
def test_canceled_dependency_fails_blocked_job(self, controlplane_instance_group, scm_on_launch_objects):
|
||||
"""A canceled dependency causes the blocked job to be failed with
|
||||
a 'Previous Task Canceled' explanation."""
|
||||
j, pu = _create_job_with_dependency(scm_on_launch_objects)
|
||||
|
||||
ProjectUpdate.objects.filter(pk=pu.pk).update(status='canceled', cancel_flag=True)
|
||||
|
||||
with mock.patch("awx.main.scheduler.TaskManager.start_task"):
|
||||
TaskManager().schedule()
|
||||
|
||||
j.refresh_from_db()
|
||||
assert j.status == 'failed'
|
||||
assert 'Previous Task Canceled' in j.job_explanation
|
||||
|
||||
def test_failed_dependency_fails_blocked_job(self, controlplane_instance_group, scm_on_launch_objects):
|
||||
"""A failed dependency causes the blocked job to be failed with
|
||||
a 'Previous Task Failed' explanation."""
|
||||
j, pu = _create_job_with_dependency(scm_on_launch_objects)
|
||||
|
||||
ProjectUpdate.objects.filter(pk=pu.pk).update(status='failed')
|
||||
|
||||
with mock.patch("awx.main.scheduler.TaskManager.start_task"):
|
||||
TaskManager().schedule()
|
||||
|
||||
j.refresh_from_db()
|
||||
assert j.status == 'failed'
|
||||
assert 'Previous Task Failed' in j.job_explanation
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestTaskManagerCancelsPendingJobsWithCancelFlag:
|
||||
"""When the task manager encounters pending jobs that have cancel_flag set,
|
||||
it should transition them directly to canceled status."""
|
||||
|
||||
def test_pending_job_with_cancel_flag_is_canceled(self, controlplane_instance_group, job_template_factory):
|
||||
"""A pending job with cancel_flag=True is transitioned to canceled
|
||||
by the task manager without being started."""
|
||||
objects = job_template_factory(
|
||||
'jt',
|
||||
organization='org1',
|
||||
project='proj',
|
||||
inventory='inv',
|
||||
credential='cred',
|
||||
)
|
||||
j = create_job(objects.job_template)
|
||||
j.cancel_flag = True
|
||||
j.save(update_fields=['cancel_flag'])
|
||||
|
||||
with mock.patch("awx.main.scheduler.TaskManager.start_task") as mock_start:
|
||||
TaskManager().schedule()
|
||||
|
||||
j.refresh_from_db()
|
||||
assert j.status == 'canceled'
|
||||
assert 'canceled before it started' in j.job_explanation
|
||||
assert not mock_start.called
|
||||
|
||||
def test_pending_job_without_cancel_flag_is_not_canceled(self, controlplane_instance_group, job_template_factory):
|
||||
"""A normal pending job without cancel_flag should not be canceled
|
||||
by the task manager (sanity check)."""
|
||||
objects = job_template_factory(
|
||||
'jt',
|
||||
organization='org1',
|
||||
project='proj',
|
||||
inventory='inv',
|
||||
credential='cred',
|
||||
)
|
||||
j = create_job(objects.job_template)
|
||||
|
||||
with mock.patch("awx.main.scheduler.TaskManager.start_task"):
|
||||
TaskManager().schedule()
|
||||
|
||||
j.refresh_from_db()
|
||||
assert j.status != 'canceled'
|
||||
|
||||
def test_multiple_pending_jobs_with_cancel_flag_bulk_canceled(self, controlplane_instance_group, job_template_factory):
|
||||
"""Multiple pending jobs with cancel_flag=True are all transitioned
|
||||
to canceled in a single task manager cycle."""
|
||||
objects = job_template_factory(
|
||||
'jt',
|
||||
organization='org1',
|
||||
project='proj',
|
||||
inventory='inv',
|
||||
credential='cred',
|
||||
)
|
||||
jt = objects.job_template
|
||||
jt.allow_simultaneous = True
|
||||
jt.save()
|
||||
|
||||
jobs = []
|
||||
for _ in range(3):
|
||||
j = create_job(jt)
|
||||
j.cancel_flag = True
|
||||
j.save(update_fields=['cancel_flag'])
|
||||
jobs.append(j)
|
||||
|
||||
with mock.patch("awx.main.scheduler.TaskManager.start_task") as mock_start:
|
||||
TaskManager().schedule()
|
||||
|
||||
for j in jobs:
|
||||
j.refresh_from_db()
|
||||
assert j.status == 'canceled', f"Job {j.id} should be canceled but is {j.status}"
|
||||
assert 'canceled before it started' in j.job_explanation
|
||||
assert not mock_start.called
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestWorkflowCancelFinalizesWorkflow:
|
||||
"""When a workflow job is canceled, the WorkflowManager cancels spawned child
|
||||
jobs (setting cancel_flag), the TaskManager transitions those pending jobs to
|
||||
canceled, and a final WorkflowManager pass finalizes the workflow as canceled."""
|
||||
|
||||
def test_cancel_workflow_with_parallel_nodes(self, inventory, project, controlplane_instance_group):
|
||||
"""Create a workflow with parallel nodes, cancel it after one job is
|
||||
running, and verify all jobs and the workflow reach canceled status."""
|
||||
jt = JobTemplate.objects.create(allow_simultaneous=False, inventory=inventory, project=project, playbook='helloworld.yml')
|
||||
wfjt = WorkflowJobTemplate.objects.create(name='test-cancel-wf')
|
||||
for _ in range(4):
|
||||
wfjt.workflow_nodes.create(unified_job_template=jt)
|
||||
|
||||
wj = wfjt.create_unified_job()
|
||||
wj.signal_start()
|
||||
|
||||
# TaskManager transitions workflow job to running via start_task
|
||||
TaskManager().schedule()
|
||||
wj.refresh_from_db()
|
||||
assert wj.status == 'running'
|
||||
|
||||
# WorkflowManager spawns jobs for all 4 nodes
|
||||
WorkflowManager().schedule()
|
||||
assert jt.jobs.count() == 4
|
||||
|
||||
# Simulate one job running (blocking the others via allow_simultaneous=False)
|
||||
first_job = jt.jobs.order_by('created').first()
|
||||
first_job.status = 'running'
|
||||
first_job.celery_task_id = 'fake-task-id'
|
||||
first_job.controller_node = 'test-node'
|
||||
first_job.save(update_fields=['status', 'celery_task_id', 'controller_node'])
|
||||
|
||||
# Cancel the workflow
|
||||
wj.cancel_flag = True
|
||||
wj.save(update_fields=['cancel_flag'])
|
||||
|
||||
# WorkflowManager sees cancel_flag, calls cancel_node_jobs() which sets
|
||||
# cancel_flag on all child jobs
|
||||
with mock.patch('awx.main.models.unified_jobs.UnifiedJob.cancel_dispatcher_process'):
|
||||
WorkflowManager().schedule()
|
||||
|
||||
# The running job won't actually stop in tests (no dispatcher), simulate it
|
||||
first_job.status = 'canceled'
|
||||
first_job.save(update_fields=['status'])
|
||||
|
||||
# TaskManager processes remaining pending jobs with cancel_flag set
|
||||
with mock.patch("awx.main.scheduler.TaskManager.start_task") as mock_start:
|
||||
DependencyManager().schedule()
|
||||
TaskManager().schedule()
|
||||
|
||||
for job in jt.jobs.all():
|
||||
job.refresh_from_db()
|
||||
assert job.status == 'canceled', f"Job {job.id} should be canceled but is {job.status}"
|
||||
assert not mock_start.called
|
||||
|
||||
# Final WorkflowManager pass finalizes the workflow
|
||||
WorkflowManager().schedule()
|
||||
wj.refresh_from_db()
|
||||
assert wj.status == 'canceled'
|
||||
|
||||
def test_cancel_workflow_with_approval_node(self, controlplane_instance_group):
|
||||
"""Create a workflow with a pending approval node and a downstream job
|
||||
node. Cancel the workflow and verify the approval is directly canceled
|
||||
by the WorkflowManager (since approvals are excluded from TaskManager),
|
||||
the downstream node is marked do_not_run, and the workflow finalizes."""
|
||||
approval_template = WorkflowApprovalTemplate.objects.create(name='test-approval', timeout=0)
|
||||
wfjt = WorkflowJobTemplate.objects.create(name='test-cancel-approval-wf')
|
||||
approval_node = wfjt.workflow_nodes.create(unified_job_template=approval_template)
|
||||
|
||||
# Add a downstream node (just another approval to keep it simple)
|
||||
downstream_template = WorkflowApprovalTemplate.objects.create(name='test-downstream', timeout=0)
|
||||
downstream_node = wfjt.workflow_nodes.create(unified_job_template=downstream_template)
|
||||
approval_node.success_nodes.add(downstream_node)
|
||||
|
||||
wj = wfjt.create_unified_job()
|
||||
wj.signal_start()
|
||||
|
||||
# TaskManager transitions workflow to running
|
||||
TaskManager().schedule()
|
||||
wj.refresh_from_db()
|
||||
assert wj.status == 'running'
|
||||
|
||||
# WorkflowManager spawns the approval (root node only, downstream waits)
|
||||
WorkflowManager().schedule()
|
||||
assert WorkflowApproval.objects.filter(unified_job_node__workflow_job=wj).count() == 1
|
||||
|
||||
approval_job = WorkflowApproval.objects.get(unified_job_node__workflow_job=wj)
|
||||
assert approval_job.status == 'pending'
|
||||
|
||||
# Cancel the workflow
|
||||
wj.cancel_flag = True
|
||||
wj.save(update_fields=['cancel_flag'])
|
||||
|
||||
# WorkflowManager should cancel the approval directly and mark
|
||||
# the downstream node as do_not_run
|
||||
WorkflowManager().schedule()
|
||||
|
||||
approval_job.refresh_from_db()
|
||||
assert approval_job.status == 'canceled', f"Approval should be canceled directly by WorkflowManager but is {approval_job.status}"
|
||||
|
||||
# Downstream node should be marked do_not_run with no job spawned
|
||||
downstream_wj_node = wj.workflow_nodes.get(unified_job_template=downstream_template)
|
||||
assert downstream_wj_node.do_not_run is True
|
||||
assert downstream_wj_node.job is None
|
||||
|
||||
# Workflow should finalize as canceled in the same pass
|
||||
wj.refresh_from_db()
|
||||
assert wj.status == 'canceled'
|
||||
@@ -1,223 +0,0 @@
|
||||
"""Functional tests for start_fact_cache / finish_fact_cache.
|
||||
|
||||
These tests use real database objects (via pytest-django) and real files
|
||||
on disk, but do not launch jobs or subprocesses. Fact files are written
|
||||
by start_fact_cache and then manipulated to simulate ansible output
|
||||
before calling finish_fact_cache.
|
||||
|
||||
Generated by Claude Opus 4.6 (claude-opus-4-6).
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from django.utils.timezone import now
|
||||
|
||||
from awx.main.models import Host, Inventory
|
||||
from awx.main.tasks.facts import start_fact_cache, finish_fact_cache
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def artifacts_dir(tmp_path):
|
||||
d = tmp_path / 'artifacts'
|
||||
d.mkdir()
|
||||
return str(d)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestFinishFactCacheScoping:
|
||||
"""finish_fact_cache must only update hosts matched by the provided queryset."""
|
||||
|
||||
def test_same_hostname_different_inventories(self, organization, artifacts_dir):
|
||||
"""Two inventories share a hostname; only the targeted one should be updated.
|
||||
|
||||
Generated by Claude Opus 4.6 (claude-opus-4-6).
|
||||
"""
|
||||
inv1 = Inventory.objects.create(organization=organization, name='scope-inv1')
|
||||
inv2 = Inventory.objects.create(organization=organization, name='scope-inv2')
|
||||
|
||||
host1 = inv1.hosts.create(name='shared')
|
||||
host2 = inv2.hosts.create(name='shared')
|
||||
|
||||
# Give both hosts initial facts
|
||||
for h in (host1, host2):
|
||||
h.ansible_facts = {'original': True}
|
||||
h.ansible_facts_modified = now()
|
||||
h.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
|
||||
|
||||
# start_fact_cache writes reference files for inv1's hosts
|
||||
start_fact_cache(inv1.hosts.all(), artifacts_dir=artifacts_dir, timeout=0, inventory_id=inv1.id)
|
||||
|
||||
# Simulate ansible writing new facts for 'shared'
|
||||
fact_file = os.path.join(artifacts_dir, 'fact_cache', 'shared')
|
||||
future = time.time() + 60
|
||||
with open(fact_file, 'w') as f:
|
||||
json.dump({'updated': True}, f)
|
||||
os.utime(fact_file, (future, future))
|
||||
|
||||
# finish with inv1's hosts as the queryset
|
||||
finish_fact_cache(inv1.hosts, artifacts_dir=artifacts_dir, inventory_id=inv1.id)
|
||||
|
||||
host1.refresh_from_db()
|
||||
host2.refresh_from_db()
|
||||
|
||||
assert host1.ansible_facts == {'updated': True}
|
||||
assert host2.ansible_facts == {'original': True}, 'Host in a different inventory was modified despite not being in the queryset'
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestFinishFactCacheConcurrentProtection:
|
||||
"""finish_fact_cache must not clear facts that a concurrent job updated."""
|
||||
|
||||
def test_no_clear_when_no_file_was_written(self, organization, artifacts_dir):
|
||||
"""Host with no prior facts should not have facts cleared when file is missing.
|
||||
|
||||
Generated by Claude Opus 4.6 (claude-opus-4-6).
|
||||
|
||||
start_fact_cache records hosts_cached[host] = False for hosts with no
|
||||
prior facts (no file written). finish_fact_cache should skip the clear
|
||||
for these hosts because the missing file is expected, not a clear signal.
|
||||
"""
|
||||
inv = Inventory.objects.create(organization=organization, name='concurrent-inv')
|
||||
host = inv.hosts.create(name='target')
|
||||
|
||||
job_created = now() - timedelta(minutes=5)
|
||||
|
||||
# start_fact_cache records host with False (no facts → no file written)
|
||||
start_fact_cache(inv.hosts.all(), artifacts_dir=artifacts_dir, timeout=0, inventory_id=inv.id)
|
||||
|
||||
# Simulate a concurrent job updating this host's facts AFTER our job was created
|
||||
host.ansible_facts = {'from_concurrent_job': True}
|
||||
host.ansible_facts_modified = now()
|
||||
host.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
|
||||
|
||||
# The fact file is missing because start_fact_cache never wrote one.
|
||||
# finish_fact_cache should skip this host entirely.
|
||||
finish_fact_cache(
|
||||
inv.hosts,
|
||||
artifacts_dir=artifacts_dir,
|
||||
inventory_id=inv.id,
|
||||
job_created=job_created,
|
||||
)
|
||||
|
||||
host.refresh_from_db()
|
||||
assert host.ansible_facts == {'from_concurrent_job': True}, 'Facts were cleared for a host that never had a fact file written'
|
||||
|
||||
def test_skip_clear_when_facts_modified_after_job_created(self, organization, artifacts_dir):
|
||||
"""If a file was written and then deleted, but facts were concurrently updated, skip clear.
|
||||
|
||||
Generated by Claude Opus 4.6 (claude-opus-4-6).
|
||||
"""
|
||||
inv = Inventory.objects.create(organization=organization, name='concurrent-written-inv')
|
||||
host = inv.hosts.create(name='target')
|
||||
|
||||
old_time = now() - timedelta(hours=1)
|
||||
host.ansible_facts = {'original': True}
|
||||
host.ansible_facts_modified = old_time
|
||||
host.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
|
||||
|
||||
job_created = now() - timedelta(minutes=5)
|
||||
|
||||
# start_fact_cache writes a file (host has facts → True in map)
|
||||
start_fact_cache(inv.hosts.all(), artifacts_dir=artifacts_dir, timeout=0, inventory_id=inv.id)
|
||||
|
||||
# Remove the fact file (ansible didn't target this host via --limit)
|
||||
os.remove(os.path.join(artifacts_dir, 'fact_cache', host.name))
|
||||
|
||||
# Simulate a concurrent job updating this host's facts AFTER our job was created
|
||||
host.ansible_facts = {'from_concurrent_job': True}
|
||||
host.ansible_facts_modified = now()
|
||||
host.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
|
||||
|
||||
finish_fact_cache(
|
||||
inv.hosts,
|
||||
artifacts_dir=artifacts_dir,
|
||||
inventory_id=inv.id,
|
||||
job_created=job_created,
|
||||
)
|
||||
|
||||
host.refresh_from_db()
|
||||
assert host.ansible_facts == {'from_concurrent_job': True}, 'Facts set by a concurrent job were cleared despite ansible_facts_modified > job_created'
|
||||
|
||||
def test_clear_when_facts_predate_job(self, organization, artifacts_dir):
|
||||
"""If facts predate the job, a missing file should still clear them.
|
||||
|
||||
Generated by Claude Opus 4.6 (claude-opus-4-6).
|
||||
"""
|
||||
inv = Inventory.objects.create(organization=organization, name='clear-inv')
|
||||
host = inv.hosts.create(name='stale')
|
||||
|
||||
old_time = now() - timedelta(hours=1)
|
||||
host.ansible_facts = {'stale': True}
|
||||
host.ansible_facts_modified = old_time
|
||||
host.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
|
||||
|
||||
job_created = now() - timedelta(minutes=5)
|
||||
|
||||
start_fact_cache(inv.hosts.all(), artifacts_dir=artifacts_dir, timeout=0, inventory_id=inv.id)
|
||||
|
||||
# Remove the fact file to simulate ansible's clear_facts
|
||||
os.remove(os.path.join(artifacts_dir, 'fact_cache', host.name))
|
||||
|
||||
finish_fact_cache(
|
||||
inv.hosts,
|
||||
artifacts_dir=artifacts_dir,
|
||||
inventory_id=inv.id,
|
||||
job_created=job_created,
|
||||
)
|
||||
|
||||
host.refresh_from_db()
|
||||
assert host.ansible_facts == {}, 'Stale facts should have been cleared when the fact file is missing ' 'and ansible_facts_modified predates job_created'
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestConstructedInventoryFactCache:
|
||||
"""finish_fact_cache with a constructed inventory queryset must target source hosts."""
|
||||
|
||||
def test_facts_resolve_to_source_host(self, organization, artifacts_dir):
|
||||
"""Facts must be written to the source host, not the constructed copy.
|
||||
|
||||
Generated by Claude Opus 4.6 (claude-opus-4-6).
|
||||
"""
|
||||
from django.db.models.functions import Cast
|
||||
|
||||
inv_input = Inventory.objects.create(organization=organization, name='ci-input')
|
||||
source_host = inv_input.hosts.create(name='webserver')
|
||||
|
||||
inv_constructed = Inventory.objects.create(organization=organization, name='ci-constructed', kind='constructed')
|
||||
inv_constructed.input_inventories.add(inv_input)
|
||||
constructed_host = Host.objects.create(
|
||||
inventory=inv_constructed,
|
||||
name='webserver',
|
||||
instance_id=str(source_host.id),
|
||||
)
|
||||
|
||||
# Build the same queryset that get_hosts_for_fact_cache uses
|
||||
id_field = Host._meta.get_field('id')
|
||||
source_qs = Host.objects.filter(id__in=inv_constructed.hosts.exclude(instance_id='').values_list(Cast('instance_id', output_field=id_field)))
|
||||
|
||||
# Give the source host initial facts so start_fact_cache writes a file
|
||||
source_host.ansible_facts = {'role': 'web'}
|
||||
source_host.ansible_facts_modified = now()
|
||||
source_host.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
|
||||
|
||||
start_fact_cache(source_qs, artifacts_dir=artifacts_dir, timeout=0, inventory_id=inv_constructed.id)
|
||||
|
||||
# Simulate ansible writing updated facts
|
||||
fact_file = os.path.join(artifacts_dir, 'fact_cache', 'webserver')
|
||||
future = time.time() + 60
|
||||
with open(fact_file, 'w') as f:
|
||||
json.dump({'role': 'web', 'deployed': True}, f)
|
||||
os.utime(fact_file, (future, future))
|
||||
|
||||
finish_fact_cache(source_qs, artifacts_dir=artifacts_dir, inventory_id=inv_constructed.id)
|
||||
|
||||
source_host.refresh_from_db()
|
||||
constructed_host.refresh_from_db()
|
||||
|
||||
assert source_host.ansible_facts == {'role': 'web', 'deployed': True}
|
||||
assert not constructed_host.ansible_facts, f'Facts were stored on the constructed host: {constructed_host.ansible_facts!r}'
|
||||
@@ -29,30 +29,3 @@ def test_cancel_flag_on_start(jt_linked, caplog):
|
||||
|
||||
job = Job.objects.get(id=job.id)
|
||||
assert job.status == 'canceled'
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_runjob_run_can_accept_waiting_status(jt_linked, mocker):
|
||||
"""Test that RunJob.run() can accept a job in 'waiting' status and transition it to 'running'
|
||||
before the pre_run_hook is called"""
|
||||
job = jt_linked.create_unified_job()
|
||||
job.status = 'waiting'
|
||||
job.save()
|
||||
|
||||
status_at_pre_run = None
|
||||
|
||||
def capture_status(instance, private_data_dir):
|
||||
nonlocal status_at_pre_run
|
||||
instance.refresh_from_db()
|
||||
status_at_pre_run = instance.status
|
||||
|
||||
mock_pre_run = mocker.patch.object(RunJob, 'pre_run_hook', side_effect=capture_status)
|
||||
|
||||
task = RunJob()
|
||||
try:
|
||||
task.run(job.id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
mock_pre_run.assert_called_once()
|
||||
assert status_at_pre_run == 'running'
|
||||
|
||||
@@ -445,142 +445,3 @@ def get_inventory_hosts(get, inv_id, use_user):
|
||||
data = get(reverse('api:inventory_hosts_list', kwargs={'pk': inv_id}), use_user, expect=200).data
|
||||
results = [host['id'] for host in data['results']]
|
||||
return results
|
||||
|
||||
|
||||
# Tests for BulkHostCreateSerializer duplicate detection optimization
|
||||
@pytest.mark.django_db
|
||||
def test_bulk_host_create_duplicate_within_batch(organization, inventory, post, user):
|
||||
"""
|
||||
Test that duplicate hostnames within the same batch are detected.
|
||||
This tests the Counter-based duplicate detection logic.
|
||||
"""
|
||||
inventory.organization = organization
|
||||
inv_admin = user('inventory_admin', False)
|
||||
organization.member_role.members.add(inv_admin)
|
||||
inventory.admin_role.members.add(inv_admin)
|
||||
|
||||
# Try to create hosts where 'duplicate-host' appears twice in the same batch
|
||||
hosts = [
|
||||
{'name': 'unique-host-1'},
|
||||
{'name': 'duplicate-host'},
|
||||
{'name': 'unique-host-2'},
|
||||
{'name': 'duplicate-host'}, # Duplicate within batch
|
||||
]
|
||||
|
||||
response = post(reverse('api:bulk_host_create'), {'inventory': inventory.id, 'hosts': hosts}, inv_admin, expect=400)
|
||||
|
||||
assert 'Hostnames must be unique in an inventory' in response.data['__all__'][0]
|
||||
assert 'duplicate-host' in response.data['__all__'][0]
|
||||
assert Host.objects.filter(inventory=inventory).count() == 0
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_bulk_host_create_duplicate_against_existing(organization, inventory, post, user):
|
||||
"""
|
||||
Test that duplicate hostnames against existing inventory hosts are detected.
|
||||
This tests the database query-based duplicate detection.
|
||||
"""
|
||||
inventory.organization = organization
|
||||
inv_admin = user('inventory_admin', False)
|
||||
organization.member_role.members.add(inv_admin)
|
||||
inventory.admin_role.members.add(inv_admin)
|
||||
|
||||
Host.objects.create(name='existing-host-1', inventory=inventory)
|
||||
Host.objects.create(name='existing-host-2', inventory=inventory)
|
||||
|
||||
# Try to create hosts where one already exists
|
||||
hosts = [
|
||||
{'name': 'new-host-1'},
|
||||
{'name': 'existing-host-1'},
|
||||
{'name': 'new-host-2'},
|
||||
]
|
||||
|
||||
response = post(reverse('api:bulk_host_create'), {'inventory': inventory.id, 'hosts': hosts}, inv_admin, expect=400)
|
||||
|
||||
assert 'Hostnames must be unique in an inventory' in response.data['__all__'][0]
|
||||
assert 'existing-host-1' in response.data['__all__'][0]
|
||||
assert Host.objects.filter(inventory=inventory).count() == 2
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_bulk_host_create_combined_duplicates(organization, inventory, post, user):
|
||||
"""
|
||||
Test detection of both batch-internal duplicates and duplicates against existing hosts.
|
||||
"""
|
||||
inventory.organization = organization
|
||||
inventory_admin = user('inventory_admin', False)
|
||||
organization.member_role.members.add(inventory_admin)
|
||||
inventory.admin_role.members.add(inventory_admin)
|
||||
|
||||
Host.objects.create(name='existing-host', inventory=inventory)
|
||||
|
||||
# Try to create hosts with both types of duplicates
|
||||
hosts = [
|
||||
{'name': 'new-host'},
|
||||
{'name': 'batch-duplicate'},
|
||||
{'name': 'existing-host'},
|
||||
{'name': 'batch-duplicate'},
|
||||
]
|
||||
|
||||
response = post(reverse('api:bulk_host_create'), {'inventory': inventory.id, 'hosts': hosts}, inventory_admin, expect=400)
|
||||
|
||||
error_message = response.data['__all__'][0]
|
||||
assert 'Hostnames must be unique in an inventory' in error_message
|
||||
assert 'batch-duplicate' in error_message or 'existing-host' in error_message
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_bulk_host_create_no_duplicates_success(organization, inventory, post, user):
|
||||
"""
|
||||
Test that hosts are created successfully when there are no duplicates.
|
||||
"""
|
||||
inventory.organization = organization
|
||||
inventory_admin = user('inventory_admin', False)
|
||||
organization.member_role.members.add(inventory_admin)
|
||||
inventory.admin_role.members.add(inventory_admin)
|
||||
|
||||
Host.objects.create(name='existing-host-1', inventory=inventory)
|
||||
Host.objects.create(name='existing-host-2', inventory=inventory)
|
||||
|
||||
# Create new hosts with unique names
|
||||
hosts = [
|
||||
{'name': 'new-host-1'},
|
||||
{'name': 'new-host-2'},
|
||||
{'name': 'new-host-3'},
|
||||
]
|
||||
|
||||
response = post(reverse('api:bulk_host_create'), {'inventory': inventory.id, 'hosts': hosts}, inventory_admin, expect=201)
|
||||
|
||||
assert len(response.data['hosts']) == 3
|
||||
assert Host.objects.filter(inventory=inventory).count() == 5
|
||||
assert Host.objects.filter(inventory=inventory, name='new-host-1').exists()
|
||||
assert Host.objects.filter(inventory=inventory, name='new-host-2').exists()
|
||||
assert Host.objects.filter(inventory=inventory, name='new-host-3').exists()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_bulk_host_create_performance_large_inventory(organization, inventory, post, user, django_assert_max_num_queries):
|
||||
"""
|
||||
Test that duplicate detection is performant and doesn't load all hosts.
|
||||
"""
|
||||
inventory.organization = organization
|
||||
inventory_admin = user('inventory_admin', False)
|
||||
organization.member_role.members.add(inventory_admin)
|
||||
inventory.admin_role.members.add(inventory_admin)
|
||||
|
||||
# Create 10k existing hosts to simulate a reasonably large inventory
|
||||
from django.utils.timezone import now
|
||||
|
||||
_now = now()
|
||||
existing_hosts = [Host(name=f'existing-host-{i}', inventory=inventory, created=_now, modified=_now) for i in range(10000)]
|
||||
Host.objects.bulk_create(existing_hosts)
|
||||
|
||||
new_hosts = [{'name': f'new-host-{i}'} for i in range(10)]
|
||||
|
||||
# The number of queries should be bounded and not scale with inventory size
|
||||
# This should be around 15-20 queries regardless of whether there are 10k or 500k+ existing hosts
|
||||
with django_assert_max_num_queries(20):
|
||||
response = post(reverse('api:bulk_host_create'), {'inventory': inventory.id, 'hosts': new_hosts}, inventory_admin, expect=201)
|
||||
|
||||
assert len(response.data['hosts']) == 10
|
||||
assert Host.objects.filter(inventory=inventory).count() == 10010
|
||||
|
||||
@@ -74,64 +74,47 @@ GLqbpJyX2r3p/Rmo6mLY71SqpA==
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_default_cred_types():
|
||||
expected = [
|
||||
'aim',
|
||||
'aws',
|
||||
'aws_secretsmanager_credential',
|
||||
'azure_kv',
|
||||
'azure_rm',
|
||||
'bitbucket_dc_token',
|
||||
'centrify_vault_kv',
|
||||
'conjur',
|
||||
'controller',
|
||||
'galaxy_api_token',
|
||||
'gce',
|
||||
'github_token',
|
||||
'github_app_lookup',
|
||||
'gitlab_token',
|
||||
'gpg_public_key',
|
||||
'hashivault_kv',
|
||||
'hashivault_ssh',
|
||||
'hcp_terraform',
|
||||
'insights',
|
||||
'kubernetes_bearer_token',
|
||||
'net',
|
||||
'openstack',
|
||||
'registry',
|
||||
'rhv',
|
||||
'satellite6',
|
||||
'scm',
|
||||
'ssh',
|
||||
'terraform',
|
||||
'thycotic_dsv',
|
||||
'thycotic_tss',
|
||||
'vault',
|
||||
'vmware',
|
||||
]
|
||||
assert sorted(CredentialType.defaults.keys()) == sorted(expected)
|
||||
assert 'hashivault-kv-oidc' not in CredentialType.defaults
|
||||
assert 'hashivault-ssh-oidc' not in CredentialType.defaults
|
||||
assert sorted(CredentialType.defaults.keys()) == sorted(
|
||||
[
|
||||
'aim',
|
||||
'aws',
|
||||
'aws_secretsmanager_credential',
|
||||
'azure_kv',
|
||||
'azure_rm',
|
||||
'bitbucket_dc_token',
|
||||
'centrify_vault_kv',
|
||||
'conjur',
|
||||
'controller',
|
||||
'galaxy_api_token',
|
||||
'gce',
|
||||
'github_token',
|
||||
'github_app_lookup',
|
||||
'gitlab_token',
|
||||
'gpg_public_key',
|
||||
'hashivault_kv',
|
||||
'hashivault_ssh',
|
||||
'hcp_terraform',
|
||||
'insights',
|
||||
'kubernetes_bearer_token',
|
||||
'net',
|
||||
'openstack',
|
||||
'registry',
|
||||
'rhv',
|
||||
'satellite6',
|
||||
'scm',
|
||||
'ssh',
|
||||
'terraform',
|
||||
'thycotic_dsv',
|
||||
'thycotic_tss',
|
||||
'vault',
|
||||
'vmware',
|
||||
]
|
||||
)
|
||||
|
||||
for type_ in CredentialType.defaults.values():
|
||||
assert type_().managed is True
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_default_cred_types_with_oidc_enabled():
|
||||
from django.test import override_settings
|
||||
from awx.main.models.credential import load_credentials, ManagedCredentialType
|
||||
|
||||
original_registry = ManagedCredentialType.registry.copy()
|
||||
try:
|
||||
with override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=True):
|
||||
ManagedCredentialType.registry.clear()
|
||||
load_credentials()
|
||||
assert 'hashivault-kv-oidc' in CredentialType.defaults
|
||||
assert 'hashivault-ssh-oidc' in CredentialType.defaults
|
||||
finally:
|
||||
ManagedCredentialType.registry = original_registry
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_credential_creation(organization_factory):
|
||||
org = organization_factory('test').organization
|
||||
|
||||
@@ -8,7 +8,6 @@ from awx.main.models import (
|
||||
Instance,
|
||||
Host,
|
||||
JobHostSummary,
|
||||
Inventory,
|
||||
InventoryUpdate,
|
||||
InventorySource,
|
||||
Project,
|
||||
@@ -18,60 +17,14 @@ from awx.main.models import (
|
||||
InstanceGroup,
|
||||
Label,
|
||||
ExecutionEnvironment,
|
||||
Credential,
|
||||
CredentialType,
|
||||
CredentialInputSource,
|
||||
Organization,
|
||||
JobTemplate,
|
||||
)
|
||||
from awx.main.tasks import jobs
|
||||
from awx.main.tasks.system import cluster_node_heartbeat
|
||||
from awx.main.utils.db import bulk_update_sorted_by_id
|
||||
from ansible_base.lib.testing.util import feature_flag_enabled, feature_flag_disabled
|
||||
|
||||
from django.db import OperationalError
|
||||
from django.test.utils import override_settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def job_template_with_credentials():
|
||||
"""
|
||||
Factory fixture that creates a job template with specified credentials.
|
||||
|
||||
Usage:
|
||||
job = job_template_with_credentials(ssh_cred, vault_cred)
|
||||
"""
|
||||
|
||||
def _create_job_template(
|
||||
*credentials, org_name='test-org', project_name='test-project', inventory_name='test-inventory', jt_name='test-jt', playbook='test.yml'
|
||||
):
|
||||
"""
|
||||
Create a job template with the given credentials.
|
||||
|
||||
Args:
|
||||
*credentials: Variable number of Credential objects to attach to the job template
|
||||
org_name: Name for the organization
|
||||
project_name: Name for the project
|
||||
inventory_name: Name for the inventory
|
||||
jt_name: Name for the job template
|
||||
playbook: Playbook filename
|
||||
|
||||
Returns:
|
||||
Job instance created from the job template
|
||||
"""
|
||||
org = Organization.objects.create(name=org_name)
|
||||
proj = Project.objects.create(name=project_name, organization=org)
|
||||
inv = Inventory.objects.create(name=inventory_name, organization=org)
|
||||
jt = JobTemplate.objects.create(name=jt_name, project=proj, inventory=inv, playbook=playbook)
|
||||
|
||||
if credentials:
|
||||
jt.credentials.add(*credentials)
|
||||
|
||||
return jt.create_unified_job()
|
||||
|
||||
return _create_job_template
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_orphan_unified_job_creation(instance, inventory):
|
||||
job = Job.objects.create(job_template=None, inventory=inventory, name='hi world')
|
||||
@@ -309,442 +262,3 @@ class TestLaunchConfig:
|
||||
assert config.execution_environment
|
||||
# We just write the PK instead of trying to assign an item, that happens on the save
|
||||
assert config.execution_environment_id == ee.id
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_base_task_credentials_property(job_template_with_credentials):
|
||||
"""Test that _credentials property caches credentials and doesn't re-query."""
|
||||
task = jobs.RunJob()
|
||||
|
||||
# Create real credentials
|
||||
ssh_type = CredentialType.defaults['ssh']()
|
||||
ssh_type.save()
|
||||
vault_type = CredentialType.defaults['vault']()
|
||||
vault_type.save()
|
||||
|
||||
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
|
||||
vault_cred = Credential.objects.create(credential_type=vault_type, name='vault-cred')
|
||||
|
||||
# Create a job with credentials using fixture
|
||||
job = job_template_with_credentials(ssh_cred, vault_cred)
|
||||
task.instance = job
|
||||
|
||||
# First access should build credentials
|
||||
result1 = task._credentials
|
||||
assert len(result1) == 2
|
||||
assert isinstance(result1, list)
|
||||
|
||||
# Second access should return cached value (we can verify by checking it's the same list object)
|
||||
result2 = task._credentials
|
||||
assert result2 is result1 # Same object reference
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_run_job_machine_credential(job_template_with_credentials):
|
||||
"""Test _machine_credential returns ssh credential from cache."""
|
||||
task = jobs.RunJob()
|
||||
|
||||
# Create credentials
|
||||
ssh_type = CredentialType.defaults['ssh']()
|
||||
ssh_type.save()
|
||||
vault_type = CredentialType.defaults['vault']()
|
||||
vault_type.save()
|
||||
|
||||
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
|
||||
vault_cred = Credential.objects.create(credential_type=vault_type, name='vault-cred')
|
||||
|
||||
# Create a job using fixture
|
||||
job = job_template_with_credentials(ssh_cred, vault_cred)
|
||||
task.instance = job
|
||||
|
||||
# Set cached credentials
|
||||
task._credentials = [ssh_cred, vault_cred]
|
||||
|
||||
# Get machine credential
|
||||
result = task._machine_credential
|
||||
assert result == ssh_cred
|
||||
assert result.credential_type.kind == 'ssh'
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_run_job_machine_credential_none(job_template_with_credentials):
|
||||
"""Test _machine_credential returns None when no ssh credential exists."""
|
||||
task = jobs.RunJob()
|
||||
|
||||
# Create only vault credential
|
||||
vault_type = CredentialType.defaults['vault']()
|
||||
vault_type.save()
|
||||
vault_cred = Credential.objects.create(credential_type=vault_type, name='vault-cred')
|
||||
|
||||
job = job_template_with_credentials(vault_cred)
|
||||
task.instance = job
|
||||
|
||||
# Set cached credentials
|
||||
task._credentials = [vault_cred]
|
||||
|
||||
# Get machine credential
|
||||
result = task._machine_credential
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_run_job_vault_credentials(job_template_with_credentials):
|
||||
"""Test _vault_credentials returns all vault credentials from cache."""
|
||||
task = jobs.RunJob()
|
||||
|
||||
# Create credentials
|
||||
vault_type = CredentialType.defaults['vault']()
|
||||
vault_type.save()
|
||||
ssh_type = CredentialType.defaults['ssh']()
|
||||
ssh_type.save()
|
||||
|
||||
vault_cred1 = Credential.objects.create(credential_type=vault_type, name='vault-1')
|
||||
vault_cred2 = Credential.objects.create(credential_type=vault_type, name='vault-2')
|
||||
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
|
||||
|
||||
job = job_template_with_credentials(vault_cred1, ssh_cred, vault_cred2)
|
||||
task.instance = job
|
||||
|
||||
# Set cached credentials
|
||||
task._credentials = [vault_cred1, ssh_cred, vault_cred2]
|
||||
|
||||
# Get vault credentials
|
||||
result = task._vault_credentials
|
||||
assert len(result) == 2
|
||||
assert vault_cred1 in result
|
||||
assert vault_cred2 in result
|
||||
assert ssh_cred not in result
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_run_job_network_credentials(job_template_with_credentials):
|
||||
"""Test _network_credentials returns all network credentials from cache."""
|
||||
task = jobs.RunJob()
|
||||
|
||||
# Create credentials
|
||||
net_type = CredentialType.defaults['net']()
|
||||
net_type.save()
|
||||
ssh_type = CredentialType.defaults['ssh']()
|
||||
ssh_type.save()
|
||||
|
||||
net_cred = Credential.objects.create(credential_type=net_type, name='net-cred')
|
||||
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
|
||||
|
||||
job = job_template_with_credentials(net_cred, ssh_cred)
|
||||
task.instance = job
|
||||
|
||||
# Set cached credentials
|
||||
task._credentials = [net_cred, ssh_cred]
|
||||
|
||||
# Get network credentials
|
||||
result = task._network_credentials
|
||||
assert len(result) == 1
|
||||
assert result[0] == net_cred
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_run_job_cloud_credentials(job_template_with_credentials):
|
||||
"""Test _cloud_credentials returns all cloud credentials from cache."""
|
||||
task = jobs.RunJob()
|
||||
|
||||
# Create credentials
|
||||
aws_type = CredentialType.defaults['aws']()
|
||||
aws_type.save()
|
||||
ssh_type = CredentialType.defaults['ssh']()
|
||||
ssh_type.save()
|
||||
|
||||
aws_cred = Credential.objects.create(credential_type=aws_type, name='aws-cred')
|
||||
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
|
||||
|
||||
job = job_template_with_credentials(aws_cred, ssh_cred)
|
||||
task.instance = job
|
||||
|
||||
# Set cached credentials
|
||||
task._credentials = [aws_cred, ssh_cred]
|
||||
|
||||
# Get cloud credentials
|
||||
result = task._cloud_credentials
|
||||
assert len(result) == 1
|
||||
assert result[0] == aws_cred
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(RESOURCE_SERVER={'URL': 'https://gateway.example.com', 'SECRET_KEY': 'test-secret-key', 'VALIDATE_HTTPS': False})
|
||||
def test_populate_workload_identity_tokens_with_flag_enabled(job_template_with_credentials, mocker):
|
||||
"""Test populate_workload_identity_tokens sets context when flag is enabled."""
|
||||
with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
|
||||
task = jobs.RunJob()
|
||||
|
||||
# Create credential types
|
||||
ssh_type = CredentialType.defaults['ssh']()
|
||||
ssh_type.save()
|
||||
|
||||
# Create a workload identity credential type
|
||||
hashivault_type = CredentialType(
|
||||
name='HashiCorp Vault Secret Lookup (OIDC)',
|
||||
kind='cloud',
|
||||
managed=False,
|
||||
inputs={
|
||||
'fields': [
|
||||
{'id': 'jwt_aud', 'type': 'string', 'label': 'JWT Audience'},
|
||||
{'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
|
||||
]
|
||||
},
|
||||
)
|
||||
hashivault_type.save()
|
||||
|
||||
# Create credentials
|
||||
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
|
||||
source_cred = Credential.objects.create(credential_type=hashivault_type, name='vault-source', inputs={'jwt_aud': 'https://vault.example.com'})
|
||||
target_cred = Credential.objects.create(credential_type=ssh_type, name='target-cred', inputs={'username': 'testuser'})
|
||||
|
||||
# Create input source linking source credential to target credential
|
||||
input_source = CredentialInputSource.objects.create(
|
||||
target_credential=target_cred, source_credential=source_cred, input_field_name='password', metadata={'path': 'secret/data/password'}
|
||||
)
|
||||
|
||||
# Create a job using fixture
|
||||
job = job_template_with_credentials(target_cred, ssh_cred)
|
||||
task.instance = job
|
||||
|
||||
# Override cached_property so the loop uses these exact Python objects
|
||||
task._credentials = [target_cred, ssh_cred]
|
||||
|
||||
# Mock only the HTTP response from the Gateway workload identity endpoint
|
||||
mock_response = mocker.Mock(status_code=200)
|
||||
mock_response.json.return_value = {'jwt': 'eyJ.test.jwt'}
|
||||
|
||||
mock_request = mocker.patch('requests.request', return_value=mock_response, autospec=True)
|
||||
|
||||
task.populate_workload_identity_tokens()
|
||||
|
||||
# Verify the HTTP call was made to the correct endpoint
|
||||
mock_request.assert_called_once()
|
||||
call_kwargs = mock_request.call_args.kwargs
|
||||
assert call_kwargs['method'] == 'POST'
|
||||
assert '/api/gateway/v1/workload_identity_tokens' in call_kwargs['url']
|
||||
|
||||
# Verify context was set on the credential, keyed by input source PK
|
||||
assert input_source.pk in target_cred.context
|
||||
assert target_cred.context[input_source.pk]['workload_identity_token'] == 'eyJ.test.jwt'
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(RESOURCE_SERVER={'URL': 'https://gateway.example.com', 'SECRET_KEY': 'test-secret-key', 'VALIDATE_HTTPS': False})
|
||||
def test_populate_workload_identity_tokens_passes_workload_ttl_from_job_timeout(job_template_with_credentials, mocker):
|
||||
"""Test populate_workload_identity_tokens passes workload_ttl_seconds from get_instance_timeout to the client."""
|
||||
with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
|
||||
task = jobs.RunJob()
|
||||
|
||||
ssh_type = CredentialType.defaults['ssh']()
|
||||
ssh_type.save()
|
||||
|
||||
hashivault_type = CredentialType(
|
||||
name='HashiCorp Vault Secret Lookup (OIDC)',
|
||||
kind='cloud',
|
||||
managed=False,
|
||||
inputs={
|
||||
'fields': [
|
||||
{'id': 'jwt_aud', 'type': 'string', 'label': 'JWT Audience'},
|
||||
{'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
|
||||
]
|
||||
},
|
||||
)
|
||||
hashivault_type.save()
|
||||
|
||||
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
|
||||
source_cred = Credential.objects.create(credential_type=hashivault_type, name='vault-source', inputs={'jwt_aud': 'https://vault.example.com'})
|
||||
target_cred = Credential.objects.create(credential_type=ssh_type, name='target-cred', inputs={'username': 'testuser'})
|
||||
|
||||
CredentialInputSource.objects.create(
|
||||
target_credential=target_cred, source_credential=source_cred, input_field_name='password', metadata={'path': 'secret/data/password'}
|
||||
)
|
||||
|
||||
job = job_template_with_credentials(target_cred, ssh_cred)
|
||||
job.timeout = 3600
|
||||
job.save()
|
||||
task.instance = job
|
||||
task._credentials = [target_cred, ssh_cred]
|
||||
|
||||
mock_response = mocker.Mock(status_code=200)
|
||||
mock_response.json.return_value = {'jwt': 'eyJ.test.jwt'}
|
||||
mock_request = mocker.patch('requests.request', return_value=mock_response, autospec=True)
|
||||
|
||||
task.populate_workload_identity_tokens()
|
||||
|
||||
call_kwargs = mock_request.call_args.kwargs
|
||||
assert call_kwargs['method'] == 'POST'
|
||||
json_body = call_kwargs.get('json', {})
|
||||
assert json_body.get('workload_ttl_seconds') == 3600
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_populate_workload_identity_tokens_with_flag_disabled(job_template_with_credentials):
|
||||
"""Test populate_workload_identity_tokens sets error status when flag is disabled."""
|
||||
with feature_flag_disabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
|
||||
task = jobs.RunJob()
|
||||
|
||||
# Create credential types
|
||||
ssh_type = CredentialType.defaults['ssh']()
|
||||
ssh_type.save()
|
||||
|
||||
# Create a workload identity credential type
|
||||
hashivault_type = CredentialType(
|
||||
name='HashiCorp Vault Secret Lookup (OIDC)',
|
||||
kind='cloud',
|
||||
managed=False,
|
||||
inputs={
|
||||
'fields': [
|
||||
{'id': 'jwt_aud', 'type': 'string', 'label': 'JWT Audience'},
|
||||
{'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
|
||||
]
|
||||
},
|
||||
)
|
||||
hashivault_type.save()
|
||||
|
||||
# Create credentials
|
||||
source_cred = Credential.objects.create(credential_type=hashivault_type, name='vault-source')
|
||||
target_cred = Credential.objects.create(credential_type=ssh_type, name='target-cred', inputs={'username': 'testuser'})
|
||||
|
||||
# Create input source linking source credential to target credential
|
||||
# Note: Creates the relationship that will trigger the feature flag check
|
||||
CredentialInputSource.objects.create(
|
||||
target_credential=target_cred, source_credential=source_cred, input_field_name='password', metadata={'path': 'secret/data/password'}
|
||||
)
|
||||
|
||||
# Create a job using fixture
|
||||
job = job_template_with_credentials(target_cred)
|
||||
task.instance = job
|
||||
|
||||
# Set cached credentials
|
||||
task._credentials = [target_cred]
|
||||
|
||||
task.populate_workload_identity_tokens()
|
||||
|
||||
# Verify job status was set to error
|
||||
job.refresh_from_db()
|
||||
assert job.status == 'error'
|
||||
assert 'FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED' in job.job_explanation
|
||||
assert 'vault-source' in job.job_explanation
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(RESOURCE_SERVER={'URL': 'https://gateway.example.com', 'SECRET_KEY': 'test-secret-key', 'VALIDATE_HTTPS': False})
|
||||
def test_populate_workload_identity_tokens_multiple_input_sources_per_credential(job_template_with_credentials, mocker):
|
||||
"""Test that a single credential with two input sources from different workload identity
|
||||
credential types gets a separate JWT token for each input source, keyed by input source PK."""
|
||||
with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
|
||||
task = jobs.RunJob()
|
||||
|
||||
# Create credential types
|
||||
ssh_type = CredentialType.defaults['ssh']()
|
||||
ssh_type.save()
|
||||
|
||||
# Create two different workload identity credential types
|
||||
hashivault_kv_type = CredentialType(
|
||||
name='HashiCorp Vault Secret Lookup (OIDC)',
|
||||
kind='cloud',
|
||||
managed=False,
|
||||
inputs={
|
||||
'fields': [
|
||||
{'id': 'jwt_aud', 'type': 'string', 'label': 'JWT Audience'},
|
||||
{'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
|
||||
]
|
||||
},
|
||||
)
|
||||
hashivault_kv_type.save()
|
||||
|
||||
hashivault_ssh_type = CredentialType(
|
||||
name='HashiCorp Vault Signed SSH (OIDC)',
|
||||
kind='cloud',
|
||||
managed=False,
|
||||
inputs={
|
||||
'fields': [
|
||||
{'id': 'jwt_aud', 'type': 'string', 'label': 'JWT Audience'},
|
||||
{'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
|
||||
]
|
||||
},
|
||||
)
|
||||
hashivault_ssh_type.save()
|
||||
|
||||
# Create source credentials with different audiences
|
||||
source_cred_kv = Credential.objects.create(
|
||||
credential_type=hashivault_kv_type, name='vault-kv-source', inputs={'jwt_aud': 'https://vault-kv.example.com'}
|
||||
)
|
||||
source_cred_ssh = Credential.objects.create(
|
||||
credential_type=hashivault_ssh_type, name='vault-ssh-source', inputs={'jwt_aud': 'https://vault-ssh.example.com'}
|
||||
)
|
||||
|
||||
# Create target credential that uses both sources for different fields
|
||||
target_cred = Credential.objects.create(credential_type=ssh_type, name='target-cred', inputs={'username': 'testuser'})
|
||||
|
||||
# Create two input sources on the same target credential, each for a different field
|
||||
input_source_password = CredentialInputSource.objects.create(
|
||||
target_credential=target_cred, source_credential=source_cred_kv, input_field_name='password', metadata={'path': 'secret/data/password'}
|
||||
)
|
||||
input_source_ssh_key = CredentialInputSource.objects.create(
|
||||
target_credential=target_cred, source_credential=source_cred_ssh, input_field_name='ssh_key_data', metadata={'path': 'secret/data/ssh_key'}
|
||||
)
|
||||
|
||||
# Create a job using fixture
|
||||
job = job_template_with_credentials(target_cred)
|
||||
task.instance = job
|
||||
|
||||
# Override cached_property so the loop uses this exact Python object
|
||||
task._credentials = [target_cred]
|
||||
|
||||
# Mock HTTP responses - return different JWTs for each call
|
||||
response_kv = mocker.Mock(status_code=200)
|
||||
response_kv.json.return_value = {'jwt': 'eyJ.kv.jwt'}
|
||||
|
||||
response_ssh = mocker.Mock(status_code=200)
|
||||
response_ssh.json.return_value = {'jwt': 'eyJ.ssh.jwt'}
|
||||
|
||||
mock_request = mocker.patch('requests.request', side_effect=[response_kv, response_ssh], autospec=True)
|
||||
|
||||
task.populate_workload_identity_tokens()
|
||||
|
||||
# Verify two separate HTTP calls were made (one per input source)
|
||||
assert mock_request.call_count == 2
|
||||
|
||||
# Verify each call used the correct audience from its source credential
|
||||
audiences_requested = {call.kwargs.get('json', {}).get('audience', '') for call in mock_request.call_args_list}
|
||||
assert 'https://vault-kv.example.com' in audiences_requested
|
||||
assert 'https://vault-ssh.example.com' in audiences_requested
|
||||
|
||||
# Verify context on the target credential has both tokens, keyed by input source PK
|
||||
assert input_source_password.pk in target_cred.context
|
||||
assert input_source_ssh_key.pk in target_cred.context
|
||||
assert target_cred.context[input_source_password.pk]['workload_identity_token'] == 'eyJ.kv.jwt'
|
||||
assert target_cred.context[input_source_ssh_key.pk]['workload_identity_token'] == 'eyJ.ssh.jwt'
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_populate_workload_identity_tokens_without_workload_identity_credentials(job_template_with_credentials, mocker):
|
||||
"""Test populate_workload_identity_tokens does nothing when no workload identity credentials."""
|
||||
with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
|
||||
task = jobs.RunJob()
|
||||
|
||||
# Create only standard credentials (no workload identity)
|
||||
ssh_type = CredentialType.defaults['ssh']()
|
||||
ssh_type.save()
|
||||
vault_type = CredentialType.defaults['vault']()
|
||||
vault_type.save()
|
||||
|
||||
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
|
||||
vault_cred = Credential.objects.create(credential_type=vault_type, name='vault-cred')
|
||||
|
||||
# Create a job using fixture
|
||||
job = job_template_with_credentials(ssh_cred, vault_cred)
|
||||
task.instance = job
|
||||
|
||||
# Set cached credentials
|
||||
task._credentials = [ssh_cred, vault_cred]
|
||||
|
||||
mocker.patch('awx.main.tasks.jobs.populate_claims_for_workload', return_value={'job_id': 123}, autospec=True)
|
||||
|
||||
task.populate_workload_identity_tokens()
|
||||
|
||||
# Verify no context was set
|
||||
assert not hasattr(ssh_cred, '_context') or ssh_cred.context == {}
|
||||
assert not hasattr(vault_cred, '_context') or vault_cred.context == {}
|
||||
|
||||
@@ -18,14 +18,13 @@ from awx.main.tests.functional.conftest import * # noqa
|
||||
from awx.main.tests.conftest import load_all_credentials # noqa: F401; pylint: disable=unused-import
|
||||
from awx.main.tests import data
|
||||
|
||||
from awx.main.models import Project, JobTemplate, Organization, Inventory, WorkflowJob, UnifiedJob
|
||||
from awx.main.models import Project, JobTemplate, Organization, Inventory
|
||||
from awx.main.tasks.system import clear_setting_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
PROJ_DATA = os.path.join(os.path.dirname(data.__file__), 'projects')
|
||||
COLL_DATA = os.path.join(os.path.dirname(data.__file__), 'collections')
|
||||
|
||||
|
||||
def _copy_folders(source_path, dest_path, clear=False):
|
||||
@@ -57,7 +56,6 @@ def live_tmp_folder():
|
||||
shutil.rmtree(path)
|
||||
os.mkdir(path)
|
||||
_copy_folders(PROJ_DATA, path)
|
||||
_copy_folders(COLL_DATA, path)
|
||||
for dirname in os.listdir(path):
|
||||
source_dir = os.path.join(path, dirname)
|
||||
subprocess.run(GIT_COMMANDS, cwd=source_dir, shell=True)
|
||||
@@ -102,21 +100,6 @@ def wait_for_events(uj, timeout=2):
|
||||
|
||||
|
||||
def unified_job_stdout(uj):
|
||||
if type(uj) is UnifiedJob:
|
||||
uj = uj.get_real_instance()
|
||||
if isinstance(uj, WorkflowJob):
|
||||
outputs = []
|
||||
for node in uj.workflow_job_nodes.all().select_related('job').order_by('id'):
|
||||
if node.job is None:
|
||||
continue
|
||||
outputs.append(
|
||||
'workflow node {node_id} job {job_id} output:\n{output}'.format(
|
||||
node_id=node.id,
|
||||
job_id=node.job.id,
|
||||
output=unified_job_stdout(node.job),
|
||||
)
|
||||
)
|
||||
return '\n'.join(outputs)
|
||||
wait_for_events(uj)
|
||||
return '\n'.join([event.stdout for event in uj.get_event_queryset().order_by('created')])
|
||||
|
||||
|
||||
@@ -1,351 +0,0 @@
|
||||
"""Tests for concurrent fact caching with --limit.
|
||||
|
||||
Reproduces bugs where concurrent jobs targeting different hosts via --limit
|
||||
incorrectly modify (clear or revert) facts for hosts outside their limit.
|
||||
|
||||
Customer report: concurrent jobs on the same job template with different limits
|
||||
cause facts set by an earlier-finishing job to be rolled back when the
|
||||
later-finishing job completes.
|
||||
|
||||
See: https://github.com/jritter/concurrent-aap-fact-caching
|
||||
|
||||
Generated by Claude Opus 4.6 (claude-opus-4-6).
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from django.utils.timezone import now
|
||||
|
||||
from awx.api.versioning import reverse
|
||||
from awx.main.models import Inventory, JobTemplate
|
||||
from awx.main.tests.live.tests.conftest import wait_for_job, wait_to_leave_status
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def concurrent_facts_inventory(default_org):
|
||||
"""Inventory with two hosts for concurrent fact cache testing."""
|
||||
inv_name = 'test_concurrent_fact_cache'
|
||||
Inventory.objects.filter(organization=default_org, name=inv_name).delete()
|
||||
inv = Inventory.objects.create(organization=default_org, name=inv_name)
|
||||
inv.hosts.create(name='cc_host_0')
|
||||
inv.hosts.create(name='cc_host_1')
|
||||
return inv
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def concurrent_facts_jt(concurrent_facts_inventory, live_tmp_folder, post, admin, project_factory):
|
||||
"""Job template configured for concurrent fact-cached runs."""
|
||||
proj = project_factory(scm_url=f'file://{live_tmp_folder}/facts')
|
||||
if proj.current_job:
|
||||
wait_for_job(proj.current_job)
|
||||
assert 'gather_slow.yml' in proj.playbooks, f'gather_slow.yml not in {proj.playbooks}'
|
||||
|
||||
jt_name = 'test_concurrent_fact_cache JT'
|
||||
existing_jt = JobTemplate.objects.filter(name=jt_name).first()
|
||||
if existing_jt:
|
||||
existing_jt.delete()
|
||||
result = post(
|
||||
reverse('api:job_template_list'),
|
||||
{
|
||||
'name': jt_name,
|
||||
'project': proj.id,
|
||||
'playbook': 'gather_slow.yml',
|
||||
'inventory': concurrent_facts_inventory.id,
|
||||
'use_fact_cache': True,
|
||||
'allow_simultaneous': True,
|
||||
},
|
||||
admin,
|
||||
expect=201,
|
||||
)
|
||||
return JobTemplate.objects.get(id=result.data['id'])
|
||||
|
||||
|
||||
def test_concurrent_limit_does_not_clear_facts(concurrent_facts_inventory, concurrent_facts_jt):
|
||||
"""Concurrent jobs with different --limit must not clear each other's facts.
|
||||
|
||||
Generated by Claude Opus 4.6 (claude-opus-4-6).
|
||||
|
||||
Scenario:
|
||||
- Inventory has cc_host_0 and cc_host_1, neither has prior facts
|
||||
- Job A runs gather_slow.yml with limit=cc_host_0
|
||||
- While Job A is still running (sleeping), Job B launches with limit=cc_host_1
|
||||
- Both jobs set cacheable facts, but only for their respective limited host
|
||||
- After both complete, BOTH hosts should have populated facts
|
||||
|
||||
The bug: get_hosts_for_fact_cache() returns ALL inventory hosts regardless
|
||||
of --limit. start_fact_cache records them all in hosts_cached but writes
|
||||
no fact files (no prior facts). When the later-finishing job runs
|
||||
finish_fact_cache, it sees a missing fact file for the other job's host
|
||||
and clears that host's facts.
|
||||
"""
|
||||
inv = concurrent_facts_inventory
|
||||
jt = concurrent_facts_jt
|
||||
|
||||
# Launch Job A targeting cc_host_0
|
||||
job_a = jt.create_unified_job()
|
||||
job_a.limit = 'cc_host_0'
|
||||
job_a.save(update_fields=['limit'])
|
||||
job_a.signal_start()
|
||||
|
||||
# Wait for Job A to reach running (it will sleep inside the playbook)
|
||||
wait_to_leave_status(job_a, 'pending')
|
||||
wait_to_leave_status(job_a, 'waiting')
|
||||
logger.info(f'Job A (id={job_a.id}) is now running with limit=cc_host_0')
|
||||
|
||||
# Launch Job B targeting cc_host_1 while Job A is still running
|
||||
job_b = jt.create_unified_job()
|
||||
job_b.limit = 'cc_host_1'
|
||||
job_b.save(update_fields=['limit'])
|
||||
job_b.signal_start()
|
||||
|
||||
# Verify that Job A is still running when Job B starts,
|
||||
# otherwise the overlap that triggers the bug did not happen.
|
||||
wait_to_leave_status(job_b, 'pending')
|
||||
wait_to_leave_status(job_b, 'waiting')
|
||||
job_a.refresh_from_db()
|
||||
if job_a.status != 'running':
|
||||
pytest.skip('Job A finished before Job B started running; overlap did not occur')
|
||||
logger.info(f'Job B (id={job_b.id}) is now running with limit=cc_host_1 (concurrent with Job A)')
|
||||
|
||||
# Wait for both to complete
|
||||
wait_for_job(job_a)
|
||||
wait_for_job(job_b)
|
||||
|
||||
# Verify facts survived concurrent execution
|
||||
host_0 = inv.hosts.get(name='cc_host_0')
|
||||
host_1 = inv.hosts.get(name='cc_host_1')
|
||||
|
||||
# sanity
|
||||
job_a.refresh_from_db()
|
||||
job_b.refresh_from_db()
|
||||
assert job_a.limit == "cc_host_0"
|
||||
assert job_b.limit == "cc_host_1"
|
||||
|
||||
discovered_foos = [host_0.ansible_facts.get('foo'), host_1.ansible_facts.get('foo')]
|
||||
assert discovered_foos == ['bar'] * 2, f'Unexpected facts on cc_host_0 or _1: {discovered_foos} after job a,b {job_a.id}, {job_b.id}'
|
||||
|
||||
|
||||
def test_concurrent_limit_does_not_revert_facts(live_tmp_folder, run_job_from_playbook, concurrent_facts_inventory):
|
||||
"""Concurrent jobs must not revert facts that a prior concurrent job just set.
|
||||
|
||||
Generated by Claude Opus 4.6 (claude-opus-4-6).
|
||||
|
||||
Scenario:
|
||||
- First, populate both hosts with initial facts (foo=bar) via a
|
||||
non-concurrent gather run
|
||||
- Then run two concurrent jobs with different limits, each setting
|
||||
a new value (foo=bar_v2 via extra_vars)
|
||||
- After both complete, BOTH hosts should have foo=bar_v2
|
||||
|
||||
The bug: start_fact_cache writes the OLD facts (foo=bar) into each job's
|
||||
artifact dir for ALL hosts. If ansible's cache plugin rewrites a non-limited
|
||||
host's fact file with the stale content (updating the mtime), finish_fact_cache
|
||||
treats it as a legitimate update and overwrites the DB with old values.
|
||||
"""
|
||||
# --- Seed both hosts with initial facts via a non-concurrent run ---
|
||||
inv = concurrent_facts_inventory
|
||||
|
||||
scm_url = f'file://{live_tmp_folder}/facts'
|
||||
res = run_job_from_playbook(
|
||||
'seed_facts_for_revert_test',
|
||||
'gather_slow.yml',
|
||||
scm_url=scm_url,
|
||||
jt_params={'use_fact_cache': True, 'allow_simultaneous': True, 'inventory': inv.id},
|
||||
)
|
||||
for host in inv.hosts.all():
|
||||
assert host.ansible_facts.get('foo') == 'bar', f'Seed run failed to set facts on {host.name}: {host.ansible_facts}'
|
||||
|
||||
job = res['job']
|
||||
wait_for_job(job)
|
||||
|
||||
# sanity, jobs should be set up to both have facts with just bar
|
||||
host_0 = inv.hosts.get(name='cc_host_0')
|
||||
host_1 = inv.hosts.get(name='cc_host_1')
|
||||
discovered_foos = [host_0.ansible_facts.get('foo'), host_1.ansible_facts.get('foo')]
|
||||
assert discovered_foos == ['bar'] * 2, f'Facts did not get expected initial values: {discovered_foos}'
|
||||
|
||||
jt = job.job_template
|
||||
assert jt.allow_simultaneous is True
|
||||
assert jt.use_fact_cache is True
|
||||
|
||||
# Sanity assertion, sometimes this would give problems from the Django rel cache
|
||||
assert jt.project
|
||||
|
||||
# --- Run two concurrent jobs that write a new value ---
|
||||
# Update the JT to pass extra_vars that change the fact value
|
||||
jt.extra_vars = '{"extra_value": "_v2"}'
|
||||
jt.save(update_fields=['extra_vars'])
|
||||
|
||||
job_a = jt.create_unified_job()
|
||||
job_a.limit = 'cc_host_0'
|
||||
job_a.save(update_fields=['limit'])
|
||||
job_a.signal_start()
|
||||
|
||||
wait_to_leave_status(job_a, 'pending')
|
||||
wait_to_leave_status(job_a, 'waiting')
|
||||
|
||||
job_b = jt.create_unified_job()
|
||||
job_b.limit = 'cc_host_1'
|
||||
job_b.save(update_fields=['limit'])
|
||||
job_b.signal_start()
|
||||
|
||||
wait_to_leave_status(job_b, 'pending')
|
||||
wait_to_leave_status(job_b, 'waiting')
|
||||
job_a.refresh_from_db()
|
||||
if job_a.status != 'running':
|
||||
pytest.skip('Job A finished before Job B started running; overlap did not occur')
|
||||
|
||||
wait_for_job(job_a)
|
||||
wait_for_job(job_b)
|
||||
|
||||
host_0 = inv.hosts.get(name='cc_host_0')
|
||||
host_1 = inv.hosts.get(name='cc_host_1')
|
||||
|
||||
# Both hosts should have the UPDATED value, not the old seed value
|
||||
discovered_foos = [host_0.ansible_facts.get('foo'), host_1.ansible_facts.get('foo')]
|
||||
assert discovered_foos == ['bar_v2'] * 2, f'Facts were reverted to stale values by concurrent job cc_host_0 or cc_host_1: {discovered_foos}'
|
||||
|
||||
|
||||
def test_fact_cache_scoped_to_inventory(live_tmp_folder, default_org, run_job_from_playbook):
|
||||
"""finish_fact_cache must not modify hosts in other inventories.
|
||||
|
||||
Generated by Claude Opus 4.6 (claude-opus-4-6).
|
||||
|
||||
Bug: finish_fact_cache queries Host.objects.filter(name__in=host_names)
|
||||
without an inventory_id filter, so hosts with the same name in different
|
||||
inventories get their facts cross-contaminated.
|
||||
"""
|
||||
shared_name = 'scope_shared_host'
|
||||
|
||||
# Prepare for test by deleting junk from last run
|
||||
for inv_name in ('test_fact_scope_inv1', 'test_fact_scope_inv2'):
|
||||
inv = Inventory.objects.filter(name=inv_name).first()
|
||||
if inv:
|
||||
inv.delete()
|
||||
|
||||
inv1 = Inventory.objects.create(organization=default_org, name='test_fact_scope_inv1')
|
||||
inv1.hosts.create(name=shared_name)
|
||||
|
||||
inv2 = Inventory.objects.create(organization=default_org, name='test_fact_scope_inv2')
|
||||
host2 = inv2.hosts.create(name=shared_name)
|
||||
|
||||
# Give inv2's host distinct facts that should not be touched
|
||||
original_facts = {'source': 'inventory_2', 'untouched': True}
|
||||
host2.ansible_facts = original_facts
|
||||
host2.ansible_facts_modified = now()
|
||||
host2.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
|
||||
|
||||
# Run a fact-gathering job against inv1 only
|
||||
run_job_from_playbook(
|
||||
'test_fact_scope',
|
||||
'gather.yml',
|
||||
scm_url=f'file://{live_tmp_folder}/facts',
|
||||
jt_params={'use_fact_cache': True, 'inventory': inv1.id},
|
||||
)
|
||||
|
||||
# inv1's host should have facts
|
||||
host1 = inv1.hosts.get(name=shared_name)
|
||||
assert host1.ansible_facts, f'inv1 host should have facts after gather: {host1.ansible_facts}'
|
||||
|
||||
# inv2's host must NOT have been touched
|
||||
host2.refresh_from_db()
|
||||
assert host2.ansible_facts == original_facts, (
|
||||
f'Host in a different inventory was modified by a fact cache operation '
|
||||
f'on another inventory sharing the same hostname. '
|
||||
f'Expected {original_facts!r}, got {host2.ansible_facts!r}'
|
||||
)
|
||||
|
||||
|
||||
def test_constructed_inventory_facts_saved_to_source_host(live_tmp_folder, default_org, run_job_from_playbook):
|
||||
"""Facts from a constructed inventory job must be saved to the source host.
|
||||
|
||||
Generated by Claude Opus 4.6 (claude-opus-4-6).
|
||||
|
||||
Constructed inventories contain hosts that are references (via instance_id)
|
||||
to 'real' hosts in input inventories. start_fact_cache correctly resolves
|
||||
source hosts via get_hosts_for_fact_cache(), but finish_fact_cache must also
|
||||
write facts back to the source hosts, not the constructed inventory's copies.
|
||||
|
||||
Scenario:
|
||||
- Two input inventories each have a host named 'ci_shared_host'
|
||||
- A constructed inventory uses both as inputs
|
||||
- The inventory sync picks one source host (via instance_id) for the
|
||||
constructed host — which one depends on input processing order
|
||||
- Both source hosts start with distinct pre-existing facts
|
||||
- A fact-gathering job runs against the constructed inventory
|
||||
- After completion, the targeted source host should have the job's facts
|
||||
- The OTHER source host must retain its original facts untouched
|
||||
- The constructed host itself must NOT have facts stored on it
|
||||
(constructed hosts are transient — recreated on each inventory sync)
|
||||
"""
|
||||
shared_name = 'ci_shared_host'
|
||||
|
||||
# Cleanup from prior runs
|
||||
for inv_name in ('test_ci_facts_input1', 'test_ci_facts_input2', 'test_ci_facts_constructed'):
|
||||
Inventory.objects.filter(name=inv_name).delete()
|
||||
|
||||
# --- Create two input inventories, each with an identically-named host ---
|
||||
inv1 = Inventory.objects.create(organization=default_org, name='test_ci_facts_input1')
|
||||
source_host1 = inv1.hosts.create(name=shared_name)
|
||||
|
||||
inv2 = Inventory.objects.create(organization=default_org, name='test_ci_facts_input2')
|
||||
source_host2 = inv2.hosts.create(name=shared_name)
|
||||
|
||||
# Give both hosts distinct pre-existing facts so we can detect cross-contamination
|
||||
host1_original_facts = {'source': 'inventory_1'}
|
||||
source_host1.ansible_facts = host1_original_facts
|
||||
source_host1.ansible_facts_modified = now()
|
||||
source_host1.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
|
||||
|
||||
host2_original_facts = {'source': 'inventory_2'}
|
||||
source_host2.ansible_facts = host2_original_facts
|
||||
source_host2.ansible_facts_modified = now()
|
||||
source_host2.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
|
||||
|
||||
source_hosts_by_id = {source_host1.id: source_host1, source_host2.id: source_host2}
|
||||
original_facts_by_id = {source_host1.id: host1_original_facts, source_host2.id: host2_original_facts}
|
||||
|
||||
# --- Create constructed inventory (sync will create hosts from inputs) ---
|
||||
constructed_inv = Inventory.objects.create(
|
||||
organization=default_org,
|
||||
name='test_ci_facts_constructed',
|
||||
kind='constructed',
|
||||
)
|
||||
constructed_inv.input_inventories.add(inv1)
|
||||
constructed_inv.input_inventories.add(inv2)
|
||||
|
||||
# --- Run a fact-gathering job against the constructed inventory ---
|
||||
# The job launch triggers an inventory sync which creates the constructed
|
||||
# host with instance_id pointing to one of the source hosts.
|
||||
scm_url = f'file://{live_tmp_folder}/facts'
|
||||
run_job_from_playbook(
|
||||
'test_ci_facts',
|
||||
'gather.yml',
|
||||
scm_url=scm_url,
|
||||
jt_params={'use_fact_cache': True, 'inventory': constructed_inv.id},
|
||||
)
|
||||
|
||||
# --- Determine which source host the constructed host points to ---
|
||||
constructed_host = constructed_inv.hosts.get(name=shared_name)
|
||||
target_id = int(constructed_host.instance_id)
|
||||
other_id = (set(source_hosts_by_id.keys()) - {target_id}).pop()
|
||||
|
||||
target_host = source_hosts_by_id[target_id]
|
||||
other_host = source_hosts_by_id[other_id]
|
||||
|
||||
target_host.refresh_from_db()
|
||||
other_host.refresh_from_db()
|
||||
constructed_host.refresh_from_db()
|
||||
|
||||
actual = [target_host.ansible_facts.get('foo'), other_host.ansible_facts, constructed_host.ansible_facts]
|
||||
expected = ['bar', original_facts_by_id[other_id], {}]
|
||||
assert actual == expected, (
|
||||
f'Constructed inventory fact cache wrote to wrong host(s). '
|
||||
f'target source host (id={target_id}) foo={actual[0]!r}, '
|
||||
f'other source host (id={other_id}) facts={actual[1]!r}, '
|
||||
f'constructed host facts={actual[2]!r}; expected {expected!r}'
|
||||
)
|
||||
@@ -1,347 +0,0 @@
|
||||
"""
|
||||
Integration tests for external query file functionality (AAP-58470).
|
||||
|
||||
Tests verify the end-to-end external query file workflow for indirect node
|
||||
counting using real AWX job execution. A fixture-created vendor collection
|
||||
at /var/lib/awx/vendor_collections/ provides external query files, simulating
|
||||
what the build-time (AAP-58426) and deployment (AAP-58557) integrations will
|
||||
provide once available.
|
||||
|
||||
Test data:
|
||||
- Collection 'demo.external' at various versions (no embedded query)
|
||||
- External query files in mock redhat.indirect_accounting collection
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
import yaml
|
||||
|
||||
import pytest
|
||||
from flags.state import enable_flag, disable_flag, flag_enabled
|
||||
|
||||
from awx.main.tests.live.tests.conftest import wait_for_events, unified_job_stdout
|
||||
from awx.main.tasks.host_indirect import save_indirect_host_entries
|
||||
from awx.main.models.indirect_managed_node_audit import IndirectManagedNodeAudit
|
||||
from awx.main.models.event_query import EventQuery
|
||||
from awx.main.models import Job
|
||||
|
||||
# --- Constants ---
|
||||
|
||||
EXTERNAL_QUERY_JQ = '{name: .name, canonical_facts: {host_name: .direct_host_name}, facts: {device_type: .device_type}}'
|
||||
|
||||
EXTERNAL_QUERY_CONTENT = yaml.dump(
|
||||
{'demo.external.example': {'query': EXTERNAL_QUERY_JQ}},
|
||||
default_flow_style=False,
|
||||
)
|
||||
|
||||
# For precedence test: different jq (no device_type in facts) so we can detect which query was used
|
||||
EXTERNAL_QUERY_FOR_DEMO_QUERY_JQ = '{name: .name, canonical_facts: {host_name: .direct_host_name}, facts: {}}'
|
||||
|
||||
EXTERNAL_QUERY_FOR_DEMO_QUERY_CONTENT = yaml.dump(
|
||||
{'demo.query.example': {'query': EXTERNAL_QUERY_FOR_DEMO_QUERY_JQ}},
|
||||
default_flow_style=False,
|
||||
)
|
||||
|
||||
VENDOR_COLLECTIONS_BASE = '/var/lib/awx/vendor_collections'
|
||||
|
||||
|
||||
# --- Fixtures ---
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def enable_indirect_host_counting():
|
||||
"""Enable FEATURE_INDIRECT_NODE_COUNTING_ENABLED flag for the test.
|
||||
|
||||
Only creates a FlagState DB record if the flag isn't already enabled
|
||||
(e.g. via development_defaults.py), to avoid UniqueViolation errors
|
||||
and to avoid leaking state to other tests.
|
||||
"""
|
||||
flag_name = "FEATURE_INDIRECT_NODE_COUNTING_ENABLED"
|
||||
was_enabled = flag_enabled(flag_name)
|
||||
if not was_enabled:
|
||||
enable_flag(flag_name)
|
||||
yield
|
||||
if not was_enabled:
|
||||
disable_flag(flag_name)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vendor_collections_dir():
|
||||
"""Set up mock redhat.indirect_accounting collection at /var/lib/awx/vendor_collections/.
|
||||
|
||||
Creates the collection structure with external query files:
|
||||
- demo.external.1.0.0.yml (exact match for v1.0.0)
|
||||
- demo.external.1.1.0.yml (fallback target for v1.5.0)
|
||||
- demo.query.0.0.1.yml (for precedence test with embedded-query collection)
|
||||
"""
|
||||
base = os.path.join(VENDOR_COLLECTIONS_BASE, 'ansible_collections', 'redhat', 'indirect_accounting')
|
||||
queries_path = os.path.join(base, 'extensions', 'audit', 'external_queries')
|
||||
meta_path = os.path.join(base, 'meta')
|
||||
|
||||
os.makedirs(queries_path, exist_ok=True)
|
||||
os.makedirs(meta_path, exist_ok=True)
|
||||
|
||||
# galaxy.yml for valid collection structure
|
||||
with open(os.path.join(base, 'galaxy.yml'), 'w') as f:
|
||||
yaml.dump(
|
||||
{
|
||||
'namespace': 'redhat',
|
||||
'name': 'indirect_accounting',
|
||||
'version': '1.0.0',
|
||||
'description': 'Test fixture for external query integration tests',
|
||||
'authors': ['AWX Tests'],
|
||||
'dependencies': {},
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
# meta/runtime.yml
|
||||
with open(os.path.join(meta_path, 'runtime.yml'), 'w') as f:
|
||||
yaml.dump({'requires_ansible': '>=2.15.0'}, f)
|
||||
|
||||
# External query files for demo.external collection
|
||||
for version in ('1.0.0', '1.1.0'):
|
||||
with open(os.path.join(queries_path, f'demo.external.{version}.yml'), 'w') as f:
|
||||
f.write(EXTERNAL_QUERY_CONTENT)
|
||||
|
||||
# External query file for demo.query collection (precedence test)
|
||||
with open(os.path.join(queries_path, 'demo.query.0.0.1.yml'), 'w') as f:
|
||||
f.write(EXTERNAL_QUERY_FOR_DEMO_QUERY_CONTENT)
|
||||
|
||||
yield base
|
||||
|
||||
# Cleanup: only remove the collection we created, not the entire vendor root
|
||||
shutil.rmtree(base, ignore_errors=True)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_test_data():
|
||||
"""Clean up EventQuery and IndirectManagedNodeAudit records after each test."""
|
||||
yield
|
||||
EventQuery.objects.filter(fqcn='demo.external').delete()
|
||||
EventQuery.objects.filter(fqcn='demo.query').delete()
|
||||
IndirectManagedNodeAudit.objects.filter(job__name__icontains='external_query').delete()
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
|
||||
def run_external_query_job(run_job_from_playbook, live_tmp_folder, test_name, project_dir, jt_params=None):
|
||||
"""Run a job and return the Job object after waiting for indirect host processing."""
|
||||
scm_url = f'file://{live_tmp_folder}/{project_dir}'
|
||||
run_job_from_playbook(test_name, 'run_task.yml', scm_url=scm_url, jt_params=jt_params)
|
||||
|
||||
job = Job.objects.filter(name__icontains=test_name).order_by('-created').first()
|
||||
assert job is not None, f'Job not found for test {test_name}'
|
||||
wait_for_events(job)
|
||||
|
||||
return job
|
||||
|
||||
|
||||
def wait_for_indirect_processing(job, expect_records=True, timeout=5):
|
||||
"""Wait for indirect host processing to complete.
|
||||
|
||||
Follows the same pattern as test_indirect_host_counting.py:53-72.
|
||||
"""
|
||||
# Ensure indirect host processing runs (wait_for_events already called by caller)
|
||||
job.refresh_from_db()
|
||||
if job.event_queries_processed is False:
|
||||
save_indirect_host_entries.delay(job.id, wait_for_events=False)
|
||||
|
||||
if expect_records:
|
||||
# Poll for audit records to appear
|
||||
for _ in range(20):
|
||||
if IndirectManagedNodeAudit.objects.filter(job=job).exists():
|
||||
break
|
||||
time.sleep(0.25)
|
||||
else:
|
||||
raise RuntimeError(f'No IndirectManagedNodeAudit records populated for job_id={job.id}')
|
||||
else:
|
||||
# For negative tests, wait a reasonable time to confirm no records appear
|
||||
time.sleep(timeout)
|
||||
job.refresh_from_db()
|
||||
|
||||
|
||||
# --- AC8.1: External query populates IndirectManagedNodeAudit correctly ---
|
||||
|
||||
|
||||
def test_external_query_populates_audit_table(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
|
||||
"""AC8.1: Job using demo.external.example with external query file populates
|
||||
IndirectManagedNodeAudit table correctly.
|
||||
|
||||
Uses demo.external v1.0.0 with exact-match external query file demo.external.1.0.0.yml.
|
||||
"""
|
||||
job = run_external_query_job(
|
||||
run_job_from_playbook,
|
||||
live_tmp_folder,
|
||||
'external_query_ac8_1',
|
||||
'test_host_query_external_v1_0_0',
|
||||
)
|
||||
wait_for_indirect_processing(job, expect_records=True)
|
||||
|
||||
# Verify installed_collections captured demo.external
|
||||
assert 'demo.external' in job.installed_collections
|
||||
assert 'host_query' in job.installed_collections['demo.external']
|
||||
|
||||
# Verify IndirectManagedNodeAudit records
|
||||
assert IndirectManagedNodeAudit.objects.filter(job=job).count() == 1
|
||||
host_audit = IndirectManagedNodeAudit.objects.filter(job=job).first()
|
||||
|
||||
assert host_audit.canonical_facts == {'host_name': 'foo_host_default'}
|
||||
assert host_audit.facts == {'device_type': 'Fake Host'}
|
||||
assert host_audit.name == 'vm-foo'
|
||||
assert host_audit.organization == job.organization
|
||||
assert 'demo.external.example' in host_audit.events
|
||||
|
||||
|
||||
# --- AC8.2: Precedence - embedded query takes precedence over external ---
|
||||
|
||||
|
||||
def test_embedded_query_takes_precedence(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
|
||||
"""AC8.2: When collection has both embedded and external query files,
|
||||
the embedded query takes precedence.
|
||||
|
||||
Uses demo.query v0.0.1 which HAS an embedded query (extensions/audit/event_query.yml).
|
||||
An external query (demo.query.0.0.1.yml) also exists but uses a different jq expression
|
||||
(no device_type in facts). By checking the audit record's facts, we verify which query was used.
|
||||
"""
|
||||
# Run with demo.query collection (has embedded query)
|
||||
job = run_external_query_job(
|
||||
run_job_from_playbook,
|
||||
live_tmp_folder,
|
||||
'external_query_ac8_2',
|
||||
'test_host_query',
|
||||
)
|
||||
wait_for_indirect_processing(job, expect_records=True)
|
||||
|
||||
# Verify the embedded query was used (includes device_type in facts)
|
||||
host_audit = IndirectManagedNodeAudit.objects.filter(job=job).first()
|
||||
assert host_audit.facts == {'device_type': 'Fake Host'}, (
|
||||
'Expected embedded query output (with device_type). ' 'If facts is {}, the external query was incorrectly used instead.'
|
||||
)
|
||||
|
||||
|
||||
# --- AC8.3: Version fallback to compatible version ---
|
||||
|
||||
|
||||
def test_fallback_to_compatible_version(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
|
||||
"""AC8.3: Job using collection version with no exact query file falls back
|
||||
correctly to compatible version.
|
||||
|
||||
Uses demo.external v1.5.0. No demo.external.1.5.0.yml exists, but
|
||||
demo.external.1.1.0.yml is available (same major version, highest <= 1.5.0).
|
||||
The fallback should find and use the 1.1.0 query.
|
||||
"""
|
||||
job = run_external_query_job(
|
||||
run_job_from_playbook,
|
||||
live_tmp_folder,
|
||||
'external_query_ac8_3',
|
||||
'test_host_query_external_v1_5_0',
|
||||
)
|
||||
wait_for_indirect_processing(job, expect_records=True)
|
||||
|
||||
# Verify installed_collections captured demo.external at v1.5.0
|
||||
assert 'demo.external' in job.installed_collections
|
||||
assert job.installed_collections['demo.external']['version'] == '1.5.0'
|
||||
|
||||
# Verify IndirectManagedNodeAudit records were created via fallback
|
||||
assert IndirectManagedNodeAudit.objects.filter(job=job).count() == 1
|
||||
host_audit = IndirectManagedNodeAudit.objects.filter(job=job).first()
|
||||
|
||||
assert host_audit.canonical_facts == {'host_name': 'foo_host_default'}
|
||||
assert host_audit.facts == {'device_type': 'Fake Host'}
|
||||
assert host_audit.name == 'vm-foo'
|
||||
|
||||
|
||||
# --- AC8.4: Fallback queries don't overcount ---
|
||||
|
||||
|
||||
def test_fallback_does_not_overcount(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
|
||||
"""AC8.4: Fallback queries don't count MORE nodes than exact-version queries.
|
||||
|
||||
Runs two jobs:
|
||||
1. Exact match scenario (demo.external v1.0.0 -> demo.external.1.0.0.yml)
|
||||
2. Fallback scenario (demo.external v1.5.0 -> falls back to demo.external.1.1.0.yml)
|
||||
|
||||
Verifies that fallback record count <= exact record count.
|
||||
"""
|
||||
# Run exact-match job
|
||||
exact_job = run_external_query_job(
|
||||
run_job_from_playbook,
|
||||
live_tmp_folder,
|
||||
'external_query_ac8_4_exact',
|
||||
'test_host_query_external_v1_0_0',
|
||||
)
|
||||
wait_for_indirect_processing(exact_job, expect_records=True)
|
||||
exact_count = IndirectManagedNodeAudit.objects.filter(job=exact_job).count()
|
||||
|
||||
# Run fallback job
|
||||
fallback_job = run_external_query_job(
|
||||
run_job_from_playbook,
|
||||
live_tmp_folder,
|
||||
'external_query_ac8_4_fallback',
|
||||
'test_host_query_external_v1_5_0',
|
||||
)
|
||||
wait_for_indirect_processing(fallback_job, expect_records=True)
|
||||
fallback_count = IndirectManagedNodeAudit.objects.filter(job=fallback_job).count()
|
||||
|
||||
# Critical safety check: fallback must never count MORE than exact
|
||||
assert fallback_count <= exact_count, (
|
||||
f'Overcounting detected! Fallback produced {fallback_count} records ' f'but exact match produced only {exact_count} records.'
|
||||
)
|
||||
|
||||
# Both use the same jq expression and same module, so counts should be equal
|
||||
assert exact_count == fallback_count
|
||||
|
||||
|
||||
# --- AC8.5: Warning logs contain correct version information ---
|
||||
|
||||
|
||||
def test_fallback_log_contains_version_info(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
|
||||
"""AC8.5: Warning logs contain correct version information when fallback is used.
|
||||
|
||||
Runs a job with verbosity=1 so callback plugin verbose output is captured.
|
||||
Verifies the log contains the installed version (1.5.0), fallback version (1.1.0),
|
||||
and collection FQCN (demo.external).
|
||||
"""
|
||||
job = run_external_query_job(
|
||||
run_job_from_playbook,
|
||||
live_tmp_folder,
|
||||
'external_query_ac8_5',
|
||||
'test_host_query_external_v1_5_0',
|
||||
jt_params={'verbosity': 1},
|
||||
)
|
||||
wait_for_indirect_processing(job, expect_records=True)
|
||||
|
||||
# Get job stdout to check for fallback log message
|
||||
stdout = unified_job_stdout(job)
|
||||
|
||||
# The callback plugin emits: "Using external query {version_used} for {fqcn} v{ver}."
|
||||
assert '1.1.0' in stdout, f'Fallback version 1.1.0 not found in job stdout. stdout:\n{stdout}'
|
||||
assert 'demo.external' in stdout, f'Collection FQCN demo.external not found in job stdout. stdout:\n{stdout}'
|
||||
assert '1.5.0' in stdout, f'Installed version 1.5.0 not found in job stdout. stdout:\n{stdout}'
|
||||
|
||||
|
||||
# --- AC8.6: No counting when no compatible fallback exists ---
|
||||
|
||||
|
||||
def test_no_counting_without_compatible_fallback(live_tmp_folder, run_job_from_playbook, enable_indirect_host_counting, vendor_collections_dir):
|
||||
"""AC8.6: No counting occurs when no compatible fallback exists.
|
||||
|
||||
Uses demo.external v3.0.0 with only v1.x external query files available.
|
||||
Since major versions differ (3 vs 1), no fallback should occur and no
|
||||
IndirectManagedNodeAudit records should be created.
|
||||
"""
|
||||
job = run_external_query_job(
|
||||
run_job_from_playbook,
|
||||
live_tmp_folder,
|
||||
'external_query_ac8_6',
|
||||
'test_host_query_external_v3_0_0',
|
||||
)
|
||||
wait_for_indirect_processing(job, expect_records=False)
|
||||
|
||||
# No audit records should exist for this job
|
||||
assert IndirectManagedNodeAudit.objects.filter(job=job).count() == 0, (
|
||||
'IndirectManagedNodeAudit records were created despite no compatible ' 'fallback existing for demo.external v3.0.0 (only v1.x queries available).'
|
||||
)
|
||||
@@ -1,206 +0,0 @@
|
||||
import json
|
||||
import pytest
|
||||
|
||||
from awx.main.tests.live.tests.conftest import wait_for_job
|
||||
|
||||
from awx.main.models import JobTemplate, WorkflowJobTemplate, WorkflowJobTemplateNode
|
||||
|
||||
JT_NAMES = ('artifact-test-first', 'artifact-test-second', 'artifact-test-reader')
|
||||
WFT_NAMES = ('artifact-test-outer-wf', 'artifact-test-inner-wf')
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_nested_workflow_set_stats_precedence(live_tmp_folder, demo_inv, project_factory, default_org):
|
||||
"""Reproducer for set_stats artifacts from an outer workflow leaking into
|
||||
an inner (child) workflow and overriding the inner workflow's own artifacts.
|
||||
|
||||
Outer WF: [job_first] --success--> [inner_wf]
|
||||
Inner WF: [job_second] --success--> [job_reader]
|
||||
|
||||
job_first sets via set_stats:
|
||||
var1: "outer-only" (only source, should propagate through)
|
||||
var2: "should-be-overridden" (will be overridden by job_second)
|
||||
|
||||
job_second sets via set_stats:
|
||||
var2: "from-inner" (should override outer's value)
|
||||
var3: "inner-only" (only source, should be available)
|
||||
|
||||
job_reader runs debug.yml (no set_stats), we inspect its extra_vars:
|
||||
var1 should be "outer-only" - outer artifacts propagate when uncontested
|
||||
var2 should be "from-inner" - inner artifacts override outer (THE BUG)
|
||||
var3 should be "inner-only" - inner-only artifacts propagate normally
|
||||
"""
|
||||
# Clean up resources from prior runs (delete individually for signals)
|
||||
for name in WFT_NAMES:
|
||||
for wft in WorkflowJobTemplate.objects.filter(name=name):
|
||||
wft.delete()
|
||||
for name in JT_NAMES:
|
||||
for jt in JobTemplate.objects.filter(name=name):
|
||||
jt.delete()
|
||||
|
||||
proj = project_factory(scm_url=f'file://{live_tmp_folder}/debug')
|
||||
if proj.current_job:
|
||||
wait_for_job(proj.current_job)
|
||||
|
||||
# job_first: sets var1 (outer-only) and var2 (to be overridden by inner)
|
||||
jt_first = JobTemplate.objects.create(
|
||||
name='artifact-test-first',
|
||||
project=proj,
|
||||
playbook='set_stats.yml',
|
||||
inventory=demo_inv,
|
||||
extra_vars=json.dumps({'stats_data': {'var1': 'outer-only', 'var2': 'should-be-overridden'}}),
|
||||
)
|
||||
# job_second: overrides var2, introduces var3
|
||||
jt_second = JobTemplate.objects.create(
|
||||
name='artifact-test-second',
|
||||
project=proj,
|
||||
playbook='set_stats.yml',
|
||||
inventory=demo_inv,
|
||||
extra_vars=json.dumps({'stats_data': {'var2': 'from-inner', 'var3': 'inner-only'}}),
|
||||
)
|
||||
# job_reader: just runs, we check what extra_vars it receives
|
||||
jt_reader = JobTemplate.objects.create(
|
||||
name='artifact-test-reader',
|
||||
project=proj,
|
||||
playbook='debug.yml',
|
||||
inventory=demo_inv,
|
||||
)
|
||||
|
||||
# Inner WFT: job_second -> job_reader
|
||||
inner_wft = WorkflowJobTemplate.objects.create(name='artifact-test-inner-wf', organization=default_org)
|
||||
inner_node_1 = WorkflowJobTemplateNode.objects.create(
|
||||
workflow_job_template=inner_wft,
|
||||
unified_job_template=jt_second,
|
||||
identifier='second',
|
||||
)
|
||||
inner_node_2 = WorkflowJobTemplateNode.objects.create(
|
||||
workflow_job_template=inner_wft,
|
||||
unified_job_template=jt_reader,
|
||||
identifier='reader',
|
||||
)
|
||||
inner_node_1.success_nodes.add(inner_node_2)
|
||||
|
||||
# Outer WFT: job_first -> inner_wf
|
||||
outer_wft = WorkflowJobTemplate.objects.create(name='artifact-test-outer-wf', organization=default_org)
|
||||
outer_node_1 = WorkflowJobTemplateNode.objects.create(
|
||||
workflow_job_template=outer_wft,
|
||||
unified_job_template=jt_first,
|
||||
identifier='first',
|
||||
)
|
||||
outer_node_2 = WorkflowJobTemplateNode.objects.create(
|
||||
workflow_job_template=outer_wft,
|
||||
unified_job_template=inner_wft,
|
||||
identifier='inner',
|
||||
)
|
||||
outer_node_1.success_nodes.add(outer_node_2)
|
||||
|
||||
# Launch and wait
|
||||
outer_wfj = outer_wft.create_unified_job()
|
||||
outer_wfj.signal_start()
|
||||
wait_for_job(outer_wfj, running_timeout=120)
|
||||
|
||||
# Find the reader job inside the inner workflow
|
||||
inner_wf_node = outer_wfj.workflow_job_nodes.get(identifier='inner')
|
||||
inner_wfj = inner_wf_node.job
|
||||
assert inner_wfj is not None, 'Inner workflow job was never created'
|
||||
|
||||
# Check that root node of inner WF (job_second) received outer artifacts
|
||||
second_node = inner_wfj.workflow_job_nodes.get(identifier='second')
|
||||
assert second_node.job is not None, 'Second job was never created'
|
||||
second_extra_vars = json.loads(second_node.job.extra_vars)
|
||||
assert second_extra_vars.get('var1') == 'outer-only', (
|
||||
f'Root node var1: expected "outer-only" (outer artifact should be available to root node), '
|
||||
f'got "{second_extra_vars.get("var1")}". '
|
||||
f'Outer artifacts are not reaching root nodes of child workflows.'
|
||||
)
|
||||
|
||||
reader_node = inner_wfj.workflow_job_nodes.get(identifier='reader')
|
||||
assert reader_node.job is not None, 'Reader job was never created'
|
||||
|
||||
reader_extra_vars = json.loads(reader_node.job.extra_vars)
|
||||
|
||||
# var1: only set by outer job_first, no conflict — should propagate through
|
||||
assert reader_extra_vars.get('var1') == 'outer-only', f'var1: expected "outer-only" (uncontested outer artifact), ' f'got "{reader_extra_vars.get("var1")}"'
|
||||
|
||||
# var2: set by outer as "should-be-overridden", then by inner as "from-inner"
|
||||
# Inner workflow's own ancestor artifacts should take precedence
|
||||
assert reader_extra_vars.get('var2') == 'from-inner', (
|
||||
f'var2: expected "from-inner" (inner workflow artifact should override outer), '
|
||||
f'got "{reader_extra_vars.get("var2")}". '
|
||||
f'Outer workflow artifacts are leaking via wj_special_vars. '
|
||||
f'reader node ancestor_artifacts={reader_node.ancestor_artifacts}'
|
||||
)
|
||||
|
||||
# var3: only set by inner job_second — should propagate normally
|
||||
assert reader_extra_vars.get('var3') == 'inner-only', f'var3: expected "inner-only" (inner-only artifact), ' f'got "{reader_extra_vars.get("var3")}"'
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_workflow_extra_vars_override_artifacts(live_tmp_folder, demo_inv, project_factory, default_org):
|
||||
"""Workflow extra_vars should take precedence over set_stats artifacts
|
||||
within a single (non-nested) workflow.
|
||||
|
||||
WF (extra_vars: my_var="from-wf-extra-vars"):
|
||||
[job_setter] --success--> [job_reader]
|
||||
|
||||
job_setter sets my_var="from-set-stats" via set_stats
|
||||
job_reader should see my_var="from-wf-extra-vars" because workflow
|
||||
extra_vars are higher precedence than ancestor artifacts.
|
||||
"""
|
||||
wft_name = 'artifact-test-wf-extra-vars-precedence'
|
||||
jt_names = ('artifact-test-setter', 'artifact-test-checker')
|
||||
|
||||
for wft in WorkflowJobTemplate.objects.filter(name=wft_name):
|
||||
wft.delete()
|
||||
for name in jt_names:
|
||||
for jt in JobTemplate.objects.filter(name=name):
|
||||
jt.delete()
|
||||
|
||||
proj = project_factory(scm_url=f'file://{live_tmp_folder}/debug')
|
||||
if proj.current_job:
|
||||
wait_for_job(proj.current_job)
|
||||
|
||||
jt_setter = JobTemplate.objects.create(
|
||||
name='artifact-test-setter',
|
||||
project=proj,
|
||||
playbook='set_stats.yml',
|
||||
inventory=demo_inv,
|
||||
extra_vars=json.dumps({'stats_data': {'my_var': 'from-set-stats'}}),
|
||||
)
|
||||
jt_checker = JobTemplate.objects.create(
|
||||
name='artifact-test-checker',
|
||||
project=proj,
|
||||
playbook='debug.yml',
|
||||
inventory=demo_inv,
|
||||
)
|
||||
|
||||
wft = WorkflowJobTemplate.objects.create(
|
||||
name=wft_name,
|
||||
organization=default_org,
|
||||
extra_vars=json.dumps({'my_var': 'from-wf-extra-vars'}),
|
||||
)
|
||||
node_1 = WorkflowJobTemplateNode.objects.create(
|
||||
workflow_job_template=wft,
|
||||
unified_job_template=jt_setter,
|
||||
identifier='setter',
|
||||
)
|
||||
node_2 = WorkflowJobTemplateNode.objects.create(
|
||||
workflow_job_template=wft,
|
||||
unified_job_template=jt_checker,
|
||||
identifier='checker',
|
||||
)
|
||||
node_1.success_nodes.add(node_2)
|
||||
|
||||
wfj = wft.create_unified_job()
|
||||
wfj.signal_start()
|
||||
wait_for_job(wfj, running_timeout=120)
|
||||
|
||||
checker_node = wfj.workflow_job_nodes.get(identifier='checker')
|
||||
assert checker_node.job is not None, 'Checker job was never created'
|
||||
|
||||
checker_extra_vars = json.loads(checker_node.job.extra_vars)
|
||||
assert checker_extra_vars.get('my_var') == 'from-wf-extra-vars', (
|
||||
f'Expected my_var="from-wf-extra-vars" (workflow extra_vars should override artifacts), '
|
||||
f'got my_var="{checker_extra_vars.get("my_var")}". '
|
||||
f'checker node ancestor_artifacts={checker_node.ancestor_artifacts}'
|
||||
)
|
||||
@@ -2,11 +2,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest import mock
|
||||
|
||||
from awx.main.models import Credential, CredentialType
|
||||
from awx.main.models.credential import CredentialTypeHelper, ManagedCredentialType
|
||||
|
||||
from django.apps import apps
|
||||
|
||||
@@ -51,84 +47,3 @@ def test__get_credential_type_class_invalid_params():
|
||||
|
||||
assert type(e.value) is ValueError
|
||||
assert str(e.value) == 'Expected only apps or app_config to be defined, not both'
|
||||
|
||||
|
||||
def test_credential_context_property():
|
||||
"""Test that credential context property initializes empty dict and persists across accesses."""
|
||||
ct = CredentialType(name='Test Cred', kind='vault')
|
||||
cred = Credential(id=1, name='Test Credential', credential_type=ct, inputs={})
|
||||
|
||||
# First access should return empty dict
|
||||
context = cred.context
|
||||
assert context == {}
|
||||
|
||||
# Modify the context
|
||||
context['test_key'] = 'test_value'
|
||||
|
||||
# Second access should return the same dict with modifications
|
||||
assert cred.context == {'test_key': 'test_value'}
|
||||
assert cred.context is context # Same object reference
|
||||
|
||||
|
||||
def test_credential_context_property_independent_instances():
|
||||
"""Test that context property is independent between credential instances."""
|
||||
ct = CredentialType(name='Test Cred', kind='vault')
|
||||
cred1 = Credential(id=1, name='Cred 1', credential_type=ct, inputs={})
|
||||
cred2 = Credential(id=2, name='Cred 2', credential_type=ct, inputs={})
|
||||
|
||||
cred1.context['key1'] = 'value1'
|
||||
cred2.context['key2'] = 'value2'
|
||||
|
||||
assert cred1.context == {'key1': 'value1'}
|
||||
assert cred2.context == {'key2': 'value2'}
|
||||
assert cred1.context is not cred2.context
|
||||
|
||||
|
||||
def test_load_plugin_passes_description():
|
||||
plugin = SimpleNamespace(name='test_plugin', inputs={'fields': []}, backend=None, plugin_description='A test plugin')
|
||||
CredentialType.load_plugin('test_ns', plugin)
|
||||
entry = ManagedCredentialType.registry['test_ns']
|
||||
assert entry.description == 'A test plugin'
|
||||
del ManagedCredentialType.registry['test_ns']
|
||||
|
||||
|
||||
def test_load_plugin_missing_description():
|
||||
plugin = SimpleNamespace(name='test_plugin', inputs={'fields': []}, backend=None)
|
||||
CredentialType.load_plugin('test_ns', plugin)
|
||||
entry = ManagedCredentialType.registry['test_ns']
|
||||
assert entry.description == ''
|
||||
del ManagedCredentialType.registry['test_ns']
|
||||
|
||||
|
||||
def test_get_creation_params_external_includes_description():
|
||||
cred_type = SimpleNamespace(namespace='test_ns', kind='external', name='Test', description='My description')
|
||||
params = CredentialTypeHelper.get_creation_params(cred_type)
|
||||
assert params['description'] == 'My description'
|
||||
|
||||
|
||||
def test_get_creation_params_external_missing_description():
|
||||
cred_type = SimpleNamespace(namespace='test_ns', kind='external', name='Test')
|
||||
params = CredentialTypeHelper.get_creation_params(cred_type)
|
||||
assert params['description'] == ''
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_setup_tower_managed_defaults_updates_description():
|
||||
registry_entry = SimpleNamespace(
|
||||
namespace='test_ns',
|
||||
kind='external',
|
||||
name='Test Plugin',
|
||||
inputs={'fields': []},
|
||||
backend=None,
|
||||
description='Updated description',
|
||||
)
|
||||
# Create an existing credential type with no description
|
||||
ct = CredentialType.objects.create(name='Test Plugin', kind='external', namespace='old_ns')
|
||||
assert ct.description == ''
|
||||
|
||||
with mock.patch.dict(ManagedCredentialType.registry, {'test_ns': registry_entry}, clear=True):
|
||||
CredentialType._setup_tower_managed_defaults()
|
||||
|
||||
ct.refresh_from_db()
|
||||
assert ct.description == 'Updated description'
|
||||
assert ct.namespace == 'test_ns'
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
import json
|
||||
import os
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
from awx.main.models import (
|
||||
Inventory,
|
||||
@@ -100,55 +99,52 @@ def test_start_job_fact_cache_within_timeout(hosts, tmpdir):
|
||||
|
||||
|
||||
def test_finish_job_fact_cache_clear(hosts, mocker, ref_time, tmpdir):
|
||||
artifacts_dir = str(tmpdir.mkdir("artifacts"))
|
||||
inventory_id = 5
|
||||
|
||||
start_fact_cache(hosts, artifacts_dir=artifacts_dir, timeout=0, inventory_id=inventory_id)
|
||||
|
||||
mocker.patch('awx.main.tasks.facts.bulk_update_sorted_by_id')
|
||||
|
||||
# Remove the fact file for hosts[1] to simulate ansible's clear_facts
|
||||
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
|
||||
os.remove(os.path.join(fact_cache_dir, hosts[1].name))
|
||||
|
||||
hosts_qs = mock.MagicMock()
|
||||
hosts_qs.filter.return_value.order_by.return_value.iterator.return_value = iter(hosts)
|
||||
|
||||
finish_fact_cache(hosts_qs, artifacts_dir=artifacts_dir, inventory_id=inventory_id)
|
||||
|
||||
# hosts[1] should have had its facts cleared (file was missing, job_created=None)
|
||||
assert hosts[1].ansible_facts == {}
|
||||
assert hosts[1].ansible_facts_modified > ref_time
|
||||
|
||||
# Other hosts should be unmodified (fact files exist but weren't changed by ansible)
|
||||
for host in (hosts[0], hosts[2], hosts[3]):
|
||||
assert host.ansible_facts == {"a": 1, "b": 2}
|
||||
assert host.ansible_facts_modified == ref_time
|
||||
|
||||
|
||||
def test_finish_job_fact_cache_with_bad_data(hosts, mocker, tmpdir):
|
||||
artifacts_dir = str(tmpdir.mkdir("artifacts"))
|
||||
inventory_id = 5
|
||||
|
||||
start_fact_cache(hosts, artifacts_dir=artifacts_dir, timeout=0, inventory_id=inventory_id)
|
||||
fact_cache = os.path.join(tmpdir, 'facts')
|
||||
start_fact_cache(hosts, fact_cache, timeout=0)
|
||||
|
||||
bulk_update = mocker.patch('awx.main.tasks.facts.bulk_update_sorted_by_id')
|
||||
|
||||
# Overwrite fact files with invalid JSON and set future mtime
|
||||
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
|
||||
# Mock the os.path.exists behavior for host deletion
|
||||
# Let's assume the fact file for hosts[1] is missing.
|
||||
mocker.patch('os.path.exists', side_effect=lambda path: hosts[1].name not in path)
|
||||
|
||||
# Simulate one host's fact file getting deleted manually
|
||||
host_to_delete_filepath = os.path.join(fact_cache, hosts[1].name)
|
||||
|
||||
# Simulate the file being removed by checking existence first, to avoid FileNotFoundError
|
||||
if os.path.exists(host_to_delete_filepath):
|
||||
os.remove(host_to_delete_filepath)
|
||||
|
||||
finish_fact_cache(fact_cache)
|
||||
|
||||
# Simulate side effects that would normally be applied during bulk update
|
||||
hosts[1].ansible_facts = {}
|
||||
hosts[1].ansible_facts_modified = now()
|
||||
|
||||
# Verify facts are preserved for hosts with valid cache files
|
||||
for host in (hosts[0], hosts[2], hosts[3]):
|
||||
assert host.ansible_facts == {"a": 1, "b": 2}
|
||||
assert host.ansible_facts_modified == ref_time
|
||||
assert hosts[1].ansible_facts_modified > ref_time
|
||||
|
||||
# Current implementation skips the call entirely if hosts_to_update == []
|
||||
bulk_update.assert_not_called()
|
||||
|
||||
|
||||
def test_finish_job_fact_cache_with_bad_data(hosts, mocker, tmpdir):
|
||||
fact_cache = os.path.join(tmpdir, 'facts')
|
||||
start_fact_cache(hosts, fact_cache, timeout=0)
|
||||
|
||||
bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update')
|
||||
|
||||
for h in hosts:
|
||||
filepath = os.path.join(fact_cache_dir, h.name)
|
||||
filepath = os.path.join(fact_cache, h.name)
|
||||
with open(filepath, 'w') as f:
|
||||
f.write('not valid json!')
|
||||
f.flush()
|
||||
new_modification_time = time.time() + 3600
|
||||
os.utime(filepath, (new_modification_time, new_modification_time))
|
||||
|
||||
hosts_qs = mock.MagicMock()
|
||||
hosts_qs.filter.return_value.order_by.return_value.iterator.return_value = iter(hosts)
|
||||
finish_fact_cache(fact_cache)
|
||||
|
||||
finish_fact_cache(hosts_qs, artifacts_dir=artifacts_dir, inventory_id=inventory_id)
|
||||
|
||||
# Invalid JSON should be skipped — no hosts updated
|
||||
updated_hosts = bulk_update.call_args[0][1]
|
||||
assert updated_hosts == []
|
||||
bulk_update.assert_not_called()
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
@@ -28,58 +32,20 @@ from ansible_base.lib.workload_identity.controller import AutomationControllerJo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def private_data_dir(tmp_path):
|
||||
private_data = tmp_path / 'awx_pdd'
|
||||
private_data.mkdir()
|
||||
def private_data_dir():
|
||||
private_data = tempfile.mkdtemp(prefix='awx_')
|
||||
for subfolder in ('inventory', 'env'):
|
||||
(private_data / subfolder).mkdir()
|
||||
return str(private_data)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def job_template_with_credentials():
|
||||
"""
|
||||
Factory fixture that creates a job template with specified credentials.
|
||||
|
||||
Usage:
|
||||
job = job_template_with_credentials(ssh_cred, vault_cred)
|
||||
"""
|
||||
|
||||
def _create_job_template(
|
||||
*credentials, org_name='test-org', project_name='test-project', inventory_name='test-inventory', jt_name='test-jt', playbook='test.yml'
|
||||
):
|
||||
"""
|
||||
Create a job template with the given credentials.
|
||||
|
||||
Args:
|
||||
*credentials: Variable number of Credential objects to attach to the job template
|
||||
org_name: Name for the organization
|
||||
project_name: Name for the project
|
||||
inventory_name: Name for the inventory
|
||||
jt_name: Name for the job template
|
||||
playbook: Playbook filename
|
||||
|
||||
Returns:
|
||||
Job instance created from the job template
|
||||
"""
|
||||
org = Organization.objects.create(name=org_name)
|
||||
proj = Project.objects.create(name=project_name, organization=org)
|
||||
inv = Inventory.objects.create(name=inventory_name, organization=org)
|
||||
jt = JobTemplate.objects.create(name=jt_name, project=proj, inventory=inv, playbook=playbook)
|
||||
|
||||
if credentials:
|
||||
jt.credentials.add(*credentials)
|
||||
|
||||
return jt.create_unified_job()
|
||||
|
||||
return _create_job_template
|
||||
runner_subfolder = os.path.join(private_data, subfolder)
|
||||
os.makedirs(runner_subfolder, exist_ok=True)
|
||||
yield private_data
|
||||
shutil.rmtree(private_data, True)
|
||||
|
||||
|
||||
@mock.patch('awx.main.tasks.facts.settings')
|
||||
@mock.patch('awx.main.tasks.jobs.create_partition', return_value=True)
|
||||
def test_pre_post_run_hook_facts(mock_create_partition, mock_facts_settings, private_data_dir, execution_environment):
|
||||
# Create mocked inventory and host queryset
|
||||
inventory = mock.MagicMock(spec=Inventory, pk=1, kind='')
|
||||
inventory = mock.MagicMock(spec=Inventory, pk=1)
|
||||
host1 = mock.MagicMock(spec=Host, id=1, name='host1', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=inventory)
|
||||
host2 = mock.MagicMock(spec=Host, id=2, name='host2', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=inventory)
|
||||
|
||||
@@ -96,16 +62,12 @@ def test_pre_post_run_hook_facts(mock_create_partition, mock_facts_settings, pri
|
||||
proj = mock.MagicMock(spec=Project, pk=1, organization=org)
|
||||
job = mock.MagicMock(
|
||||
spec=Job,
|
||||
pk=1,
|
||||
id=1,
|
||||
use_fact_cache=True,
|
||||
project=proj,
|
||||
organization=org,
|
||||
job_slice_number=1,
|
||||
job_slice_count=1,
|
||||
inventory=inventory,
|
||||
inventory_id=inventory.pk,
|
||||
created=now(),
|
||||
execution_environment=execution_environment,
|
||||
)
|
||||
job.get_hosts_for_fact_cache = Job.get_hosts_for_fact_cache.__get__(job)
|
||||
@@ -137,11 +99,9 @@ def test_pre_post_run_hook_facts(mock_create_partition, mock_facts_settings, pri
|
||||
@mock.patch('awx.main.tasks.facts.bulk_update_sorted_by_id')
|
||||
@mock.patch('awx.main.tasks.facts.settings')
|
||||
@mock.patch('awx.main.tasks.jobs.create_partition', return_value=True)
|
||||
def test_pre_post_run_hook_facts_deleted_sliced(
|
||||
mock_create_partition, mock_facts_settings, mock_bulk_update_sorted_by_id, private_data_dir, execution_environment
|
||||
):
|
||||
def test_pre_post_run_hook_facts_deleted_sliced(mock_create_partition, mock_facts_settings, private_data_dir, execution_environment):
|
||||
# Fully mocked inventory
|
||||
mock_inventory = mock.MagicMock(spec=Inventory, pk=1, kind='')
|
||||
mock_inventory = mock.MagicMock(spec=Inventory)
|
||||
|
||||
# Create 999 mocked Host instances
|
||||
hosts = []
|
||||
@@ -167,8 +127,6 @@ def test_pre_post_run_hook_facts_deleted_sliced(
|
||||
|
||||
# Mock job object
|
||||
job = mock.MagicMock(spec=Job)
|
||||
job.pk = 2
|
||||
job.id = 2
|
||||
job.use_fact_cache = True
|
||||
job.project = proj
|
||||
job.organization = org
|
||||
@@ -176,8 +134,6 @@ def test_pre_post_run_hook_facts_deleted_sliced(
|
||||
job.job_slice_count = 3
|
||||
job.execution_environment = execution_environment
|
||||
job.inventory = mock_inventory
|
||||
job.inventory_id = mock_inventory.pk
|
||||
job.created = now()
|
||||
job.job_env.get.return_value = private_data_dir
|
||||
|
||||
# Bind actual method for host filtering
|
||||
@@ -471,186 +427,3 @@ def test_populate_claims_for_adhoc_command(workload_attrs, expected_claims):
|
||||
|
||||
claims = jobs.populate_claims_for_workload(adhoc_command)
|
||||
assert claims == expected_claims
|
||||
|
||||
|
||||
@mock.patch('awx.main.utils.workload_identity.get_workload_identity_client')
|
||||
def test_retrieve_workload_identity_jwt_returns_jwt_from_client(mock_get_client):
|
||||
"""retrieve_workload_identity_jwt returns the JWT string from the client."""
|
||||
mock_client = mock.MagicMock()
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.jwt = 'eyJ.test.jwt'
|
||||
mock_client.request_workload_jwt.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
unified_job = Job()
|
||||
unified_job.id = 42
|
||||
unified_job.name = 'Test Job'
|
||||
unified_job.launch_type = 'manual'
|
||||
unified_job.organization = Organization(id=1, name='Test Org')
|
||||
unified_job.unified_job_template = None
|
||||
unified_job.instance_group = None
|
||||
|
||||
result = jobs.retrieve_workload_identity_jwt(unified_job, audience='https://api.example.com', scope='aap_controller_automation_job')
|
||||
|
||||
assert result == 'eyJ.test.jwt'
|
||||
mock_client.request_workload_jwt.assert_called_once()
|
||||
call_kwargs = mock_client.request_workload_jwt.call_args[1]
|
||||
assert call_kwargs['audience'] == 'https://api.example.com'
|
||||
assert call_kwargs['scope'] == 'aap_controller_automation_job'
|
||||
assert 'claims' in call_kwargs
|
||||
assert call_kwargs['claims'][AutomationControllerJobScope.CLAIM_JOB_ID] == 42
|
||||
assert call_kwargs['claims'][AutomationControllerJobScope.CLAIM_JOB_NAME] == 'Test Job'
|
||||
|
||||
|
||||
@mock.patch('awx.main.utils.workload_identity.get_workload_identity_client')
|
||||
def test_retrieve_workload_identity_jwt_passes_audience_and_scope(mock_get_client):
|
||||
"""retrieve_workload_identity_jwt passes audience and scope to the client."""
|
||||
mock_client = mock.MagicMock()
|
||||
mock_client.request_workload_jwt.return_value = mock.MagicMock(jwt='token')
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
unified_job = mock.MagicMock()
|
||||
audience = 'custom_audience'
|
||||
scope = 'custom_scope'
|
||||
with mock.patch('awx.main.tasks.jobs.populate_claims_for_workload', return_value={'job_id': 1}):
|
||||
jobs.retrieve_workload_identity_jwt(unified_job, audience=audience, scope=scope)
|
||||
|
||||
mock_client.request_workload_jwt.assert_called_once_with(claims={'job_id': 1}, scope=scope, audience=audience)
|
||||
|
||||
|
||||
@mock.patch('awx.main.utils.workload_identity.get_workload_identity_client')
|
||||
def test_retrieve_workload_identity_jwt_passes_workload_ttl(mock_get_client):
|
||||
"""retrieve_workload_identity_jwt passes workload_ttl_seconds when provided."""
|
||||
mock_client = mock.Mock()
|
||||
mock_client.request_workload_jwt.return_value = mock.Mock(jwt='token')
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
unified_job = mock.MagicMock()
|
||||
with mock.patch('awx.main.tasks.jobs.populate_claims_for_workload', return_value={'job_id': 1}):
|
||||
jobs.retrieve_workload_identity_jwt(
|
||||
unified_job,
|
||||
audience='https://vault.example.com',
|
||||
scope='aap_controller_automation_job',
|
||||
workload_ttl_seconds=3600,
|
||||
)
|
||||
|
||||
mock_client.request_workload_jwt.assert_called_once_with(
|
||||
claims={'job_id': 1},
|
||||
scope='aap_controller_automation_job',
|
||||
audience='https://vault.example.com',
|
||||
workload_ttl_seconds=3600,
|
||||
)
|
||||
|
||||
|
||||
@mock.patch('awx.main.utils.workload_identity.get_workload_identity_client')
|
||||
def test_retrieve_workload_identity_jwt_raises_when_client_not_configured(mock_get_client):
|
||||
"""retrieve_workload_identity_jwt raises RuntimeError when client is None."""
|
||||
mock_get_client.return_value = None
|
||||
|
||||
unified_job = mock.MagicMock()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Workload identity client is not configured"):
|
||||
jobs.retrieve_workload_identity_jwt(unified_job, audience='test_audience', scope='test_scope')
|
||||
|
||||
|
||||
@pytest.mark.parametrize('effective_timeout,expected_ttl', [(3600, 3600), (0, None)])
|
||||
@mock.patch('awx.main.tasks.jobs.retrieve_workload_identity_jwt')
|
||||
@mock.patch('awx.main.tasks.jobs.flag_enabled', return_value=True)
|
||||
def test_populate_workload_identity_tokens_passes_get_instance_timeout_to_client(mock_flag_enabled, mock_retrieve_jwt, effective_timeout, expected_ttl):
|
||||
"""populate_workload_identity_tokens passes get_instance_timeout() value as workload_ttl_seconds to retrieve_workload_identity_jwt."""
|
||||
mock_retrieve_jwt.return_value = 'eyJ.test.jwt'
|
||||
|
||||
task = jobs.RunJob()
|
||||
task.instance = mock.MagicMock()
|
||||
|
||||
# Minimal credential with workload identity input source
|
||||
credential_ctx = {}
|
||||
input_src = mock.MagicMock()
|
||||
input_src.pk = 1
|
||||
input_src.source_credential = mock.MagicMock()
|
||||
input_src.source_credential.get_input.return_value = 'https://vault.example.com'
|
||||
input_src.source_credential.name = 'vault-cred'
|
||||
input_src.source_credential.credential_type = mock.MagicMock()
|
||||
input_src.source_credential.credential_type.inputs = {'fields': [{'id': 'workload_identity_token', 'internal': True}]}
|
||||
|
||||
credential = mock.MagicMock()
|
||||
credential.context = credential_ctx
|
||||
credential.input_sources = mock.MagicMock()
|
||||
credential.input_sources.all.return_value = [input_src]
|
||||
|
||||
task._credentials = [credential]
|
||||
|
||||
with mock.patch.object(task, 'get_instance_timeout', return_value=effective_timeout):
|
||||
task.populate_workload_identity_tokens()
|
||||
|
||||
mock_flag_enabled.assert_called_once_with("FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED")
|
||||
mock_retrieve_jwt.assert_called_once_with(
|
||||
task.instance,
|
||||
audience='https://vault.example.com',
|
||||
scope=AutomationControllerJobScope.name,
|
||||
workload_ttl_seconds=expected_ttl,
|
||||
)
|
||||
|
||||
|
||||
class TestRunInventoryUpdatePopulateWorkloadIdentityTokens:
|
||||
"""Tests for RunInventoryUpdate.populate_workload_identity_tokens."""
|
||||
|
||||
def test_cloud_credential_passed_as_additional_credential(self):
|
||||
"""The cloud credential is forwarded to super().populate_workload_identity_tokens via additional_credentials."""
|
||||
cloud_cred = mock.MagicMock(name='cloud_cred')
|
||||
cloud_cred.context = {}
|
||||
|
||||
task = jobs.RunInventoryUpdate()
|
||||
task.instance = mock.MagicMock()
|
||||
task.instance.get_cloud_credential.return_value = cloud_cred
|
||||
task._credentials = []
|
||||
|
||||
with mock.patch.object(jobs.BaseTask, 'populate_workload_identity_tokens') as mock_super:
|
||||
task.populate_workload_identity_tokens()
|
||||
|
||||
mock_super.assert_called_once_with(additional_credentials=[cloud_cred])
|
||||
|
||||
def test_no_cloud_credential_calls_super_with_none(self):
|
||||
"""When there is no cloud credential, super() is called with additional_credentials=None."""
|
||||
task = jobs.RunInventoryUpdate()
|
||||
task.instance = mock.MagicMock()
|
||||
task.instance.get_cloud_credential.return_value = None
|
||||
task._credentials = []
|
||||
|
||||
with mock.patch.object(jobs.BaseTask, 'populate_workload_identity_tokens') as mock_super:
|
||||
task.populate_workload_identity_tokens()
|
||||
|
||||
mock_super.assert_called_once_with(additional_credentials=None)
|
||||
|
||||
def test_additional_credentials_combined_with_cloud_credential(self):
|
||||
"""Caller-supplied additional_credentials are combined with the cloud credential."""
|
||||
cloud_cred = mock.MagicMock(name='cloud_cred')
|
||||
cloud_cred.context = {}
|
||||
extra_cred = mock.MagicMock(name='extra_cred')
|
||||
|
||||
task = jobs.RunInventoryUpdate()
|
||||
task.instance = mock.MagicMock()
|
||||
task.instance.get_cloud_credential.return_value = cloud_cred
|
||||
task._credentials = []
|
||||
|
||||
with mock.patch.object(jobs.BaseTask, 'populate_workload_identity_tokens') as mock_super:
|
||||
task.populate_workload_identity_tokens(additional_credentials=[extra_cred])
|
||||
|
||||
mock_super.assert_called_once_with(additional_credentials=[extra_cred, cloud_cred])
|
||||
|
||||
def test_cloud_credential_override_after_context_set(self):
|
||||
"""After OIDC processing, get_cloud_credential is overridden on the instance when context is populated."""
|
||||
cloud_cred = mock.MagicMock(name='cloud_cred')
|
||||
# Simulate that super().populate_workload_identity_tokens populates context
|
||||
cloud_cred.context = {'workload_identity_token': 'eyJ.test.jwt'}
|
||||
|
||||
task = jobs.RunInventoryUpdate()
|
||||
task.instance = mock.MagicMock()
|
||||
task.instance.get_cloud_credential.return_value = cloud_cred
|
||||
task._credentials = []
|
||||
|
||||
with mock.patch.object(jobs.BaseTask, 'populate_workload_identity_tokens'):
|
||||
task.populate_workload_identity_tokens()
|
||||
|
||||
# The instance's get_cloud_credential should now return the same object with context
|
||||
assert task.instance.get_cloud_credential() is cloud_cred
|
||||
|
||||
@@ -76,9 +76,6 @@ def test_custom_error_messages(schema, given, message):
|
||||
({'fields': [{'id': 'token', 'label': 'Token', 'secret': 'bad'}]}, False),
|
||||
({'fields': [{'id': 'token', 'label': 'Token', 'ask_at_runtime': True}]}, True),
|
||||
({'fields': [{'id': 'token', 'label': 'Token', 'ask_at_runtime': 'bad'}]}, False), # noqa
|
||||
({'fields': [{'id': 'token', 'label': 'Token', 'internal': True}]}, True),
|
||||
({'fields': [{'id': 'token', 'label': 'Token', 'internal': False}]}, True),
|
||||
({'fields': [{'id': 'token', 'label': 'Token', 'internal': 'bad'}]}, False),
|
||||
({'fields': [{'id': 'become_method', 'label': 'Become', 'choices': 'not-a-list'}]}, False), # noqa
|
||||
({'fields': [{'id': 'become_method', 'label': 'Become', 'choices': []}]}, False),
|
||||
({'fields': [{'id': 'become_method', 'label': 'Become', 'choices': ['su', 'sudo']}]}, True), # noqa
|
||||
@@ -207,68 +204,6 @@ def test_credential_creation_validation_failure(inputs):
|
||||
assert e.type in (ValidationError, DRFValidationError)
|
||||
|
||||
|
||||
def test_credential_input_field_excludes_internal_fields():
|
||||
"""Internal fields should be excluded from the schema generated by CredentialInputField,
|
||||
preventing users from providing values for internally resolved fields."""
|
||||
type_ = CredentialType(
|
||||
kind='cloud',
|
||||
name='SomeCloud',
|
||||
managed=True,
|
||||
inputs={
|
||||
'fields': [
|
||||
{'id': 'username', 'label': 'Username', 'type': 'string'},
|
||||
{'id': 'resolved_token', 'label': 'Token', 'type': 'string', 'internal': True},
|
||||
]
|
||||
},
|
||||
)
|
||||
cred = Credential(credential_type=type_, name="Test Credential", inputs={'username': 'joe'})
|
||||
field = cred._meta.get_field('inputs')
|
||||
schema = field.schema(cred)
|
||||
|
||||
assert 'username' in schema['properties']
|
||||
assert 'resolved_token' not in schema['properties']
|
||||
|
||||
|
||||
def test_credential_input_field_rejects_values_for_internal_fields():
|
||||
"""Users should not be able to provide values for fields marked as internal."""
|
||||
type_ = CredentialType(
|
||||
kind='cloud',
|
||||
name='SomeCloud',
|
||||
managed=True,
|
||||
inputs={
|
||||
'fields': [
|
||||
{'id': 'username', 'label': 'Username', 'type': 'string'},
|
||||
{'id': 'resolved_token', 'label': 'Token', 'type': 'string', 'internal': True},
|
||||
]
|
||||
},
|
||||
)
|
||||
cred = Credential(credential_type=type_, name="Test Credential", inputs={'username': 'joe', 'resolved_token': 'secret'})
|
||||
field = cred._meta.get_field('inputs')
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
field.validate(cred.inputs, cred)
|
||||
assert e.type in (ValidationError, DRFValidationError)
|
||||
|
||||
|
||||
def test_credential_input_field_accepts_non_internal_fields_only():
|
||||
"""Credentials with only non-internal field values should validate successfully."""
|
||||
type_ = CredentialType(
|
||||
kind='cloud',
|
||||
name='SomeCloud',
|
||||
managed=True,
|
||||
inputs={
|
||||
'fields': [
|
||||
{'id': 'username', 'label': 'Username', 'type': 'string'},
|
||||
{'id': 'resolved_token', 'label': 'Token', 'type': 'string', 'internal': True},
|
||||
]
|
||||
},
|
||||
)
|
||||
cred = Credential(credential_type=type_, name="Test Credential", inputs={'username': 'joe'})
|
||||
field = cred._meta.get_field('inputs')
|
||||
# Should not raise
|
||||
field.validate(cred.inputs, cred)
|
||||
|
||||
|
||||
def test_implicit_role_field_parents():
|
||||
"""This assures that every ImplicitRoleField only references parents
|
||||
which are relationships that actually exist
|
||||
|
||||
@@ -1,431 +0,0 @@
|
||||
"""
|
||||
Unit tests for external query discovery and version fallback logic.
|
||||
Tests for AAP-58456: Unit Test Suite for External Query Handling
|
||||
"""
|
||||
|
||||
import sys
|
||||
from io import StringIO
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from packaging.version import Version
|
||||
|
||||
|
||||
# Helper for mocking importlib.resources.files() path traversal
|
||||
def create_chainable_path_mock(final_mock, depth=3):
|
||||
"""Mock that supports chained / operations: mock / 'a' / 'b' / 'c' -> final_mock"""
|
||||
|
||||
class ChainableMock:
|
||||
def __init__(self, d=0):
|
||||
self.d = d
|
||||
|
||||
def __truediv__(self, other):
|
||||
return final_mock if self.d >= depth - 1 else ChainableMock(self.d + 1)
|
||||
|
||||
return ChainableMock()
|
||||
|
||||
|
||||
def create_queries_dir_mock(file_lookup_func):
|
||||
"""Mock for queries_dir: mock / 'filename' -> file_lookup_func('filename')"""
|
||||
|
||||
class QueriesDirMock:
|
||||
def __truediv__(self, filename):
|
||||
return file_lookup_func(filename)
|
||||
|
||||
return QueriesDirMock()
|
||||
|
||||
|
||||
# Ansible mocking required for importing the module (it imports from ansible.plugins.callback.CallbackBase)
|
||||
class MockCallbackBase:
|
||||
def __init__(self):
|
||||
self._display = mock.MagicMock()
|
||||
|
||||
def v2_playbook_on_stats(self, stats):
|
||||
pass
|
||||
|
||||
|
||||
_mock_callback_module = mock.MagicMock()
|
||||
_mock_callback_module.CallbackBase = MockCallbackBase
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_ansible_modules():
|
||||
"""Temporarily inject fake ansible modules so the callback plugin can be imported."""
|
||||
with mock.patch.dict(
|
||||
sys.modules,
|
||||
{
|
||||
'ansible': mock.MagicMock(),
|
||||
'ansible.plugins': mock.MagicMock(),
|
||||
'ansible.plugins.callback': _mock_callback_module,
|
||||
'ansible.cli': mock.MagicMock(),
|
||||
'ansible.cli.galaxy': mock.MagicMock(),
|
||||
'ansible.release': mock.MagicMock(__version__='2.16.0'),
|
||||
'ansible.galaxy': mock.MagicMock(),
|
||||
'ansible.galaxy.collection': mock.MagicMock(),
|
||||
'ansible.utils': mock.MagicMock(),
|
||||
'ansible.utils.collection_loader': mock.MagicMock(),
|
||||
'ansible.constants': mock.MagicMock(),
|
||||
},
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestListExternalQueries:
|
||||
"""Tests for list_external_queries function."""
|
||||
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.files')
|
||||
def test_returns_empty_when_collection_not_installed(self, mock_files):
|
||||
from awx.playbooks.library.indirect_instance_count import list_external_queries
|
||||
|
||||
mock_files.side_effect = ModuleNotFoundError("No module named 'ansible_collections.redhat'")
|
||||
|
||||
result = list_external_queries('demo', 'external')
|
||||
|
||||
assert result == []
|
||||
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.files')
|
||||
def test_parses_version_from_filenames(self, mock_files):
|
||||
from awx.playbooks.library.indirect_instance_count import list_external_queries
|
||||
|
||||
mock_file_1 = mock.Mock()
|
||||
mock_file_1.name = 'demo.external.1.0.0.yml'
|
||||
mock_file_2 = mock.Mock()
|
||||
mock_file_2.name = 'demo.external.2.1.0.yml'
|
||||
mock_file_other = mock.Mock()
|
||||
mock_file_other.name = 'other.collection.1.0.0.yml'
|
||||
|
||||
mock_queries_dir = mock.Mock()
|
||||
mock_queries_dir.iterdir.return_value = [mock_file_1, mock_file_2, mock_file_other]
|
||||
mock_files.return_value = create_chainable_path_mock(mock_queries_dir)
|
||||
|
||||
result = list_external_queries('demo', 'external')
|
||||
|
||||
assert len(result) == 2
|
||||
assert Version('1.0.0') in result
|
||||
assert Version('2.1.0') in result
|
||||
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.files')
|
||||
def test_skips_invalid_versions(self, mock_files):
|
||||
from awx.playbooks.library.indirect_instance_count import list_external_queries
|
||||
|
||||
mock_file_valid = mock.Mock()
|
||||
mock_file_valid.name = 'demo.external.1.0.0.yml'
|
||||
mock_file_invalid = mock.Mock()
|
||||
mock_file_invalid.name = 'demo.external.invalid.yml'
|
||||
|
||||
mock_queries_dir = mock.Mock()
|
||||
mock_queries_dir.iterdir.return_value = [mock_file_valid, mock_file_invalid]
|
||||
mock_files.return_value = create_chainable_path_mock(mock_queries_dir)
|
||||
|
||||
result = list_external_queries('demo', 'external')
|
||||
|
||||
assert len(result) == 1
|
||||
assert Version('1.0.0') in result
|
||||
|
||||
|
||||
class TestVersionFallback:
|
||||
"""Tests for version fallback logic (AC7.4-AC7.9)."""
|
||||
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count._get_query_file_dir')
|
||||
def test_exact_match_preferred(self, mock_get_dir):
|
||||
"""AC7.4: Exact version match is preferred over fallback version."""
|
||||
from awx.playbooks.library.indirect_instance_count import find_external_query_with_fallback
|
||||
|
||||
mock_exact_file = mock.Mock()
|
||||
mock_exact_file.exists.return_value = True
|
||||
mock_exact_file.open.return_value.__enter__ = mock.Mock(return_value=StringIO('exact_version_query'))
|
||||
mock_exact_file.open.return_value.__exit__ = mock.Mock(return_value=False)
|
||||
|
||||
mock_get_dir.return_value = create_queries_dir_mock(lambda f: mock_exact_file)
|
||||
|
||||
content, fallback_used, version = find_external_query_with_fallback('demo', 'external', '2.5.0')
|
||||
|
||||
assert content == 'exact_version_query'
|
||||
assert fallback_used is False
|
||||
assert version == '2.5.0'
|
||||
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.list_external_queries')
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count._get_query_file_dir')
|
||||
def test_fallback_nearest_lower_same_major(self, mock_get_dir, mock_list):
|
||||
"""AC7.5: Fallback selects nearest lower version within same major version.
|
||||
|
||||
When installed is 4.5.0 and 4.0.0/4.1.0 are available, selects 4.1.0.
|
||||
"""
|
||||
from awx.playbooks.library.indirect_instance_count import find_external_query_with_fallback
|
||||
|
||||
mock_list.return_value = [Version('4.0.0'), Version('4.1.0')]
|
||||
|
||||
mock_exact_file = mock.Mock(exists=mock.Mock(return_value=False))
|
||||
mock_fallback_file = mock.Mock()
|
||||
mock_fallback_file.exists.return_value = True
|
||||
mock_fallback_file.open.return_value.__enter__ = mock.Mock(return_value=StringIO('fallback_query'))
|
||||
mock_fallback_file.open.return_value.__exit__ = mock.Mock(return_value=False)
|
||||
|
||||
def file_lookup(filename):
|
||||
return mock_fallback_file if '4.1.0' in filename else mock_exact_file
|
||||
|
||||
mock_get_dir.return_value = create_queries_dir_mock(file_lookup)
|
||||
|
||||
content, fallback_used, version = find_external_query_with_fallback('community', 'vmware', '4.5.0')
|
||||
|
||||
assert content == 'fallback_query'
|
||||
assert fallback_used is True
|
||||
assert version == '4.1.0'
|
||||
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.list_external_queries')
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count._get_query_file_dir')
|
||||
def test_fallback_respects_major_version_boundary(self, mock_get_dir, mock_list):
|
||||
"""Test that fallback does NOT cross major version boundaries.
|
||||
|
||||
When installed version is 6.0.0 and only 5.0.0 query exists,
|
||||
no fallback should occur because major versions differ.
|
||||
"""
|
||||
from awx.playbooks.library.indirect_instance_count import find_external_query_with_fallback
|
||||
|
||||
mock_list.return_value = [Version('5.0.0')]
|
||||
|
||||
# Mock exact file (6.0.0) to not exist
|
||||
mock_exact_file = mock.Mock(exists=mock.Mock(return_value=False))
|
||||
# Mock fallback file (5.0.0) to exist - if major version check is broken,
|
||||
# this file would be incorrectly selected
|
||||
mock_fallback_file = mock.Mock()
|
||||
mock_fallback_file.exists.return_value = True
|
||||
mock_fallback_file.open.return_value.__enter__ = mock.Mock(return_value=StringIO('wrong_major_version_query'))
|
||||
mock_fallback_file.open.return_value.__exit__ = mock.Mock(return_value=False)
|
||||
|
||||
def file_lookup(filename):
|
||||
return mock_fallback_file if '5.0.0' in filename else mock_exact_file
|
||||
|
||||
mock_get_dir.return_value = create_queries_dir_mock(file_lookup)
|
||||
|
||||
content, fallback_used, version = find_external_query_with_fallback('community', 'vmware', '6.0.0')
|
||||
|
||||
# Should NOT fall back to 5.0.0 because major version differs (5 vs 6)
|
||||
assert content is None
|
||||
assert fallback_used is False
|
||||
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.list_external_queries')
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count._get_query_file_dir')
|
||||
def test_no_fallback_when_incompatible(self, mock_get_dir, mock_list):
|
||||
"""AC7.7: No fallback when all available versions are higher than installed.
|
||||
|
||||
When installed version is 3.8.0 and only 4.0.0 and 5.0.0 exist,
|
||||
no fallback should occur because both are higher than installed.
|
||||
"""
|
||||
from awx.playbooks.library.indirect_instance_count import find_external_query_with_fallback
|
||||
|
||||
mock_list.return_value = [Version('4.0.0'), Version('5.0.0')]
|
||||
|
||||
# Mock exact file (3.8.0) to not exist
|
||||
mock_exact_file = mock.Mock(exists=mock.Mock(return_value=False))
|
||||
# Mock available files to exist - if version filtering is broken,
|
||||
# one of these would be incorrectly selected
|
||||
mock_available_file = mock.Mock()
|
||||
mock_available_file.exists.return_value = True
|
||||
mock_available_file.open.return_value.__enter__ = mock.Mock(return_value=StringIO('higher_version_query'))
|
||||
mock_available_file.open.return_value.__exit__ = mock.Mock(return_value=False)
|
||||
|
||||
def file_lookup(filename):
|
||||
if '4.0.0' in filename or '5.0.0' in filename:
|
||||
return mock_available_file
|
||||
return mock_exact_file
|
||||
|
||||
mock_get_dir.return_value = create_queries_dir_mock(file_lookup)
|
||||
|
||||
content, fallback_used, version = find_external_query_with_fallback('community', 'vmware', '3.8.0')
|
||||
|
||||
# Should NOT fall back to 4.0.0 or 5.0.0 because both are higher than 3.8.0
|
||||
assert content is None
|
||||
assert fallback_used is False
|
||||
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.list_external_queries')
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count._get_query_file_dir')
|
||||
def test_fallback_selection_logic(self, mock_get_dir, mock_list):
|
||||
"""AC7.9: Complex fallback scenario with multiple candidates.
|
||||
|
||||
When installed is 4.5.0 and 4.0.0, 4.1.0, 5.0.0 are available,
|
||||
selects 4.1.0 (highest compatible within same major, <= installed).
|
||||
"""
|
||||
from awx.playbooks.library.indirect_instance_count import find_external_query_with_fallback
|
||||
|
||||
mock_list.return_value = [Version('4.0.0'), Version('4.1.0'), Version('5.0.0')]
|
||||
|
||||
mock_exact_file = mock.Mock(exists=mock.Mock(return_value=False))
|
||||
mock_fallback_file = mock.Mock()
|
||||
mock_fallback_file.exists.return_value = True
|
||||
mock_fallback_file.open.return_value.__enter__ = mock.Mock(return_value=StringIO('query_4.1.0'))
|
||||
mock_fallback_file.open.return_value.__exit__ = mock.Mock(return_value=False)
|
||||
|
||||
def file_lookup(filename):
|
||||
return mock_fallback_file if '4.1.0' in filename else mock_exact_file
|
||||
|
||||
mock_get_dir.return_value = create_queries_dir_mock(file_lookup)
|
||||
|
||||
content, fallback_used, version = find_external_query_with_fallback('community', 'vmware', '4.5.0')
|
||||
|
||||
assert version == '4.1.0'
|
||||
assert fallback_used is True
|
||||
assert content == 'query_4.1.0'
|
||||
|
||||
|
||||
class TestExternalQueryDiscovery:
|
||||
"""Tests for callback plugin query discovery (AC7.1-AC7.3)."""
|
||||
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.list_collections')
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.files')
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.find_external_query_with_fallback')
|
||||
@mock.patch.dict('os.environ', {'AWX_ISOLATED_DATA_DIR': '/tmp/artifacts'})
|
||||
def test_precedence_embedded_over_external(self, mock_fallback, mock_files, mock_list_collections):
|
||||
"""AC7.1: Embedded query takes precedence when both embedded and external exist."""
|
||||
from awx.playbooks.library.indirect_instance_count import CallbackModule
|
||||
|
||||
mock_list_collections.return_value = [mock.Mock(namespace='demo', name='query', ver='1.0.0', fqcn='demo.query')]
|
||||
|
||||
mock_embedded_file = mock.Mock()
|
||||
mock_embedded_file.exists.return_value = True
|
||||
mock_embedded_file.open.return_value.__enter__ = mock.Mock(return_value=StringIO('embedded_query'))
|
||||
mock_embedded_file.open.return_value.__exit__ = mock.Mock(return_value=False)
|
||||
mock_files.return_value = create_chainable_path_mock(mock_embedded_file)
|
||||
|
||||
callback = CallbackModule()
|
||||
callback._display = mock.Mock()
|
||||
|
||||
with mock.patch('builtins.open', mock.mock_open()):
|
||||
with mock.patch('json.dumps', return_value='{}'):
|
||||
callback.v2_playbook_on_stats(mock.Mock())
|
||||
|
||||
mock_fallback.assert_not_called()
|
||||
callback._display.vv.assert_called()
|
||||
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.list_collections')
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.files')
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.find_external_query_with_fallback')
|
||||
@mock.patch.dict('os.environ', {'AWX_ISOLATED_DATA_DIR': '/tmp/artifacts'})
|
||||
def test_external_query_when_embedded_missing(self, mock_fallback, mock_files, mock_list_collections):
|
||||
"""AC7.2: External query is discovered when embedded query is missing."""
|
||||
from awx.playbooks.library.indirect_instance_count import CallbackModule
|
||||
|
||||
mock_candidate = mock.Mock()
|
||||
mock_candidate.namespace = 'demo'
|
||||
mock_candidate.name = 'external'
|
||||
mock_candidate.ver = '2.5.0'
|
||||
mock_candidate.fqcn = 'demo.external'
|
||||
mock_list_collections.return_value = [mock_candidate]
|
||||
|
||||
mock_embedded_file = mock.Mock(exists=mock.Mock(return_value=False))
|
||||
mock_files.return_value = create_chainable_path_mock(mock_embedded_file)
|
||||
mock_fallback.return_value = ('external_query_content', False, '2.5.0')
|
||||
|
||||
callback = CallbackModule()
|
||||
callback._display = mock.Mock()
|
||||
|
||||
with mock.patch('builtins.open', mock.mock_open()):
|
||||
with mock.patch('json.dumps', return_value='{}'):
|
||||
callback.v2_playbook_on_stats(mock.Mock())
|
||||
|
||||
mock_fallback.assert_called_once_with('demo', 'external', '2.5.0')
|
||||
callback._display.v.assert_called()
|
||||
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.list_collections')
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.files')
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.find_external_query_with_fallback')
|
||||
@mock.patch.dict('os.environ', {'AWX_ISOLATED_DATA_DIR': '/tmp/artifacts'})
|
||||
def test_no_query_when_both_missing(self, mock_fallback, mock_files, mock_list_collections):
|
||||
"""AC7.3: No query is used when both embedded and external queries are missing."""
|
||||
from awx.playbooks.library.indirect_instance_count import CallbackModule
|
||||
|
||||
mock_list_collections.return_value = [mock.Mock(namespace='unknown', name='collection', ver='1.0.0', fqcn='unknown.collection')]
|
||||
|
||||
mock_embedded_file = mock.Mock(exists=mock.Mock(return_value=False))
|
||||
mock_files.return_value = create_chainable_path_mock(mock_embedded_file)
|
||||
mock_fallback.return_value = (None, False, None)
|
||||
|
||||
callback = CallbackModule()
|
||||
callback._display = mock.Mock()
|
||||
|
||||
with mock.patch('builtins.open', mock.mock_open()):
|
||||
with mock.patch('json.dumps', return_value='{}'):
|
||||
callback.v2_playbook_on_stats(mock.Mock())
|
||||
|
||||
mock_fallback.assert_called_once()
|
||||
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.list_collections')
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.files')
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.find_external_query_with_fallback')
|
||||
@mock.patch.dict('os.environ', {'AWX_ISOLATED_DATA_DIR': '/tmp/artifacts'})
|
||||
def test_info_log_on_fallback(self, mock_fallback, mock_files, mock_list_collections):
|
||||
"""AC7.8: Log message is emitted when fallback version is used.
|
||||
|
||||
Verifies that when a fallback version is used, a log message is emitted
|
||||
containing both the fallback version and the collection FQCN.
|
||||
|
||||
Note: AC7.8 specifies 'warning logs' but implementation uses verbose/info
|
||||
level (_display.v) as this is informational rather than a warning condition.
|
||||
"""
|
||||
from awx.playbooks.library.indirect_instance_count import CallbackModule
|
||||
|
||||
mock_list_collections.return_value = [mock.Mock(namespace='community', name='vmware', ver='4.5.0', fqcn='community.vmware')]
|
||||
|
||||
mock_embedded_file = mock.Mock(exists=mock.Mock(return_value=False))
|
||||
mock_files.return_value = create_chainable_path_mock(mock_embedded_file)
|
||||
mock_fallback.return_value = ('fallback_query_content', True, '4.1.0')
|
||||
|
||||
callback = CallbackModule()
|
||||
callback._display = mock.Mock()
|
||||
|
||||
with mock.patch('builtins.open', mock.mock_open()):
|
||||
with mock.patch('json.dumps', return_value='{}'):
|
||||
callback.v2_playbook_on_stats(mock.Mock())
|
||||
|
||||
callback._display.v.assert_called()
|
||||
call_args = callback._display.v.call_args[0][0]
|
||||
assert '4.1.0' in call_args
|
||||
assert 'community.vmware' in call_args
|
||||
|
||||
|
||||
class TestPrivateDataDirIntegration:
|
||||
"""Tests for vendor collection copying (AC7.10-AC7.11)."""
|
||||
|
||||
@mock.patch('awx.main.tasks.jobs.flag_enabled')
|
||||
@mock.patch('awx.main.tasks.jobs.shutil.copytree')
|
||||
@mock.patch('awx.main.tasks.jobs.os.path.exists')
|
||||
def test_vendor_collections_copied(self, mock_exists, mock_copytree, mock_flag):
|
||||
"""AC7.10: build_private_data_files() copies vendor collections to private_data_dir."""
|
||||
from awx.main.tasks.jobs import BaseTask
|
||||
|
||||
mock_flag.return_value = True
|
||||
mock_exists.return_value = True
|
||||
|
||||
task = BaseTask()
|
||||
task.instance = mock.Mock()
|
||||
task.cleanup_paths = []
|
||||
task.build_private_data = mock.Mock(return_value=None)
|
||||
|
||||
private_data_dir = '/tmp/awx_123_abc'
|
||||
task.build_private_data_files(task.instance, private_data_dir)
|
||||
|
||||
mock_copytree.assert_called_once_with('/var/lib/awx/vendor_collections', f'{private_data_dir}/vendor_collections')
|
||||
|
||||
@mock.patch('awx.main.tasks.jobs.flag_enabled')
|
||||
@mock.patch('awx.main.tasks.jobs.logger')
|
||||
@mock.patch('awx.main.tasks.jobs.shutil.copytree')
|
||||
@mock.patch('awx.main.tasks.jobs.os.path.exists')
|
||||
def test_missing_source_handled_gracefully(self, mock_exists, mock_copytree, mock_logger, mock_flag):
|
||||
"""AC7.11: Collection copy handles missing source directory gracefully."""
|
||||
from awx.main.tasks.jobs import BaseTask
|
||||
|
||||
mock_flag.return_value = True
|
||||
mock_exists.return_value = False
|
||||
|
||||
task = BaseTask()
|
||||
task.instance = mock.Mock()
|
||||
task.cleanup_paths = []
|
||||
task.build_private_data = mock.Mock(return_value=None)
|
||||
|
||||
private_data_dir = '/tmp/awx_123_abc'
|
||||
result = task.build_private_data_files(task.instance, private_data_dir)
|
||||
|
||||
# copytree should not be called when source doesn't exist
|
||||
mock_copytree.assert_not_called()
|
||||
# Function should complete without raising an exception
|
||||
assert result is not None
|
||||
@@ -1,6 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import fcntl
|
||||
@@ -58,12 +60,14 @@ class TestJobExecution(object):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def private_data_dir(tmp_path):
|
||||
private_data = tmp_path / 'awx_pdd'
|
||||
private_data.mkdir()
|
||||
def private_data_dir():
|
||||
private_data = tempfile.mkdtemp(prefix='awx_')
|
||||
for subfolder in ('inventory', 'env'):
|
||||
(private_data / subfolder).mkdir()
|
||||
return str(private_data)
|
||||
runner_subfolder = os.path.join(private_data, subfolder)
|
||||
if not os.path.exists(runner_subfolder):
|
||||
os.mkdir(runner_subfolder)
|
||||
yield private_data
|
||||
shutil.rmtree(private_data, True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -552,8 +556,7 @@ class TestGenericRun:
|
||||
task._write_extra_vars_file = mock.Mock()
|
||||
|
||||
with mock.patch('awx.main.tasks.jobs.settings.AWX_TASK_ENV', {'FOO': 'BAR'}):
|
||||
with mock.patch.object(task, 'build_credentials_list', return_value=[], autospec=True):
|
||||
env = task.build_env(job, private_data_dir)
|
||||
env = task.build_env(job, private_data_dir)
|
||||
assert env['FOO'] == 'BAR'
|
||||
|
||||
|
||||
@@ -621,11 +624,6 @@ class TestAdhocRun(TestJobExecution):
|
||||
|
||||
|
||||
class TestJobCredentials(TestJobExecution):
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_flag_enabled(self):
|
||||
with mock.patch('awx.main.tasks.jobs.flag_enabled', return_value=False):
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def job(self, execution_environment):
|
||||
job = Job(pk=1, inventory=Inventory(pk=1), project=Project(pk=1))
|
||||
@@ -651,9 +649,7 @@ class TestJobCredentials(TestJobExecution):
|
||||
)
|
||||
|
||||
with mock.patch.object(UnifiedJob, 'credentials', credentials_mock):
|
||||
# Mock build_credentials_list to work with the cached credentials mechanism
|
||||
with mock.patch.object(jobs.RunJob, 'build_credentials_list', return_value=job._credentials, autospec=True):
|
||||
yield job
|
||||
yield job
|
||||
|
||||
@pytest.fixture
|
||||
def update_model_wrapper(self, job):
|
||||
@@ -1159,11 +1155,6 @@ class TestProjectUpdateRefspec(TestJobExecution):
|
||||
|
||||
|
||||
class TestInventoryUpdateCredentials(TestJobExecution):
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_flag_enabled(self):
|
||||
with mock.patch('awx.main.tasks.jobs.flag_enabled', return_value=False):
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def inventory_update(self, execution_environment):
|
||||
return InventoryUpdate(pk=1, execution_environment=execution_environment, inventory_source=InventorySource(pk=1, inventory=Inventory(pk=1)))
|
||||
@@ -1583,7 +1574,7 @@ def test_managed_injector_redaction(injector_cls):
|
||||
assert 'very_secret_value' not in str(build_safe_env(env))
|
||||
|
||||
|
||||
def test_job_run_no_ee(mock_me, mock_create_partition, private_data_dir):
|
||||
def test_job_run_no_ee(mock_me, mock_create_partition):
|
||||
org = Organization(pk=1)
|
||||
proj = Project(pk=1, organization=org)
|
||||
job = Job(project=proj, organization=org, inventory=Inventory(pk=1))
|
||||
|
||||
@@ -330,13 +330,17 @@ class TestHostnameRegexValidator:
|
||||
|
||||
def test_bad_call(self, regex_expr, re_flags):
|
||||
h = HostnameRegexValidator(regex=regex_expr, flags=re_flags)
|
||||
with pytest.raises(ValidationError, match=r"^\['illegal characters detected in hostname=@#\$%\)\$#\(TUFAS_DG. Please verify.'\]$"):
|
||||
try:
|
||||
h("@#$%)$#(TUFAS_DG")
|
||||
except ValidationError as e:
|
||||
assert e.message is not None
|
||||
|
||||
def test_good_call_with_inverse(self, regex_expr, re_flags, inverse_match=True):
|
||||
h = HostnameRegexValidator(regex=regex_expr, flags=re_flags, inverse_match=inverse_match)
|
||||
with pytest.raises(ValidationError, match=r"^\['Enter a valid value.'\]$"):
|
||||
try:
|
||||
h("1.2.3.4")
|
||||
except ValidationError as e:
|
||||
assert e.message is not None
|
||||
|
||||
def test_bad_call_with_inverse(self, regex_expr, re_flags, inverse_match=True):
|
||||
h = HostnameRegexValidator(regex=regex_expr, flags=re_flags, inverse_match=inverse_match)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import shutil
|
||||
import os
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -22,11 +24,17 @@ def test_switch_paths(container_path, host_path):
|
||||
assert get_incontainer_path(host_path, private_data_dir) == container_path
|
||||
|
||||
|
||||
def test_symlink_isolation_dir(tmp_path):
|
||||
src_path = tmp_path / 'symlink_src'
|
||||
dst_path = tmp_path / 'symlink_dst'
|
||||
def test_symlink_isolation_dir(request):
|
||||
rand_str = str(uuid4())[:8]
|
||||
dst_path = f'/tmp/ee_{rand_str}_symlink_dst'
|
||||
src_path = f'/tmp/ee_{rand_str}_symlink_src'
|
||||
|
||||
src_path.mkdir()
|
||||
def remove_folders():
|
||||
os.unlink(dst_path)
|
||||
shutil.rmtree(src_path)
|
||||
|
||||
request.addfinalizer(remove_folders)
|
||||
os.mkdir(src_path)
|
||||
os.symlink(src_path, dst_path)
|
||||
|
||||
pdd = f'{dst_path}/awx_xxx'
|
||||
|
||||
207
awx/main/tests/unit/utils/test_proxy.py
Normal file
207
awx/main/tests/unit/utils/test_proxy.py
Normal file
@@ -0,0 +1,207 @@
|
||||
# Copyright (c) 2024 Ansible, Inc.
|
||||
# All Rights Reserved.
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from awx.main.utils.proxy import get_first_remote_host_from_headers, is_proxy_in_headers
|
||||
|
||||
|
||||
class TestGetFirstRemoteHostFromHeaders:
|
||||
"""Tests for get_first_remote_host_from_headers function."""
|
||||
|
||||
def _make_mock_request(self, environ):
|
||||
"""Create a mock request with the given environ dict."""
|
||||
request = mock.MagicMock()
|
||||
request.environ = environ
|
||||
return request
|
||||
|
||||
def test_single_value_headers(self):
|
||||
"""Test extraction from headers with single values (no commas)."""
|
||||
request = self._make_mock_request(
|
||||
{
|
||||
"REMOTE_ADDR": "192.168.1.1",
|
||||
"REMOTE_HOST": "client.example.com",
|
||||
}
|
||||
)
|
||||
headers = ["REMOTE_ADDR", "REMOTE_HOST"]
|
||||
|
||||
result = get_first_remote_host_from_headers(request, headers)
|
||||
|
||||
assert result == {"192.168.1.1", "client.example.com"}
|
||||
|
||||
def test_comma_separated_only_first_entry(self):
|
||||
"""Test that only the first entry is extracted from comma-separated values."""
|
||||
request = self._make_mock_request(
|
||||
{
|
||||
"HTTP_X_FORWARDED_FOR": "10.0.0.1, 192.168.1.1, 172.16.0.1",
|
||||
}
|
||||
)
|
||||
headers = ["HTTP_X_FORWARDED_FOR"]
|
||||
|
||||
result = get_first_remote_host_from_headers(request, headers)
|
||||
|
||||
# Only the first IP should be included
|
||||
assert result == {"10.0.0.1"}
|
||||
# Subsequent IPs should NOT be included
|
||||
assert "192.168.1.1" not in result
|
||||
assert "172.16.0.1" not in result
|
||||
|
||||
def test_comma_separated_with_whitespace(self):
|
||||
"""Test that whitespace is properly stripped from first entry."""
|
||||
request = self._make_mock_request(
|
||||
{
|
||||
"HTTP_X_FORWARDED_FOR": " 10.0.0.1 , 192.168.1.1",
|
||||
}
|
||||
)
|
||||
headers = ["HTTP_X_FORWARDED_FOR"]
|
||||
|
||||
result = get_first_remote_host_from_headers(request, headers)
|
||||
|
||||
assert result == {"10.0.0.1"}
|
||||
|
||||
def test_multiple_headers_with_comma_separated(self):
|
||||
"""Test multiple headers where some have comma-separated values."""
|
||||
request = self._make_mock_request(
|
||||
{
|
||||
"HTTP_X_FORWARDED_FOR": "client.example.com, proxy1.example.com, proxy2.example.com",
|
||||
"REMOTE_ADDR": "172.16.0.1",
|
||||
"REMOTE_HOST": "proxy2.example.com",
|
||||
}
|
||||
)
|
||||
headers = ["HTTP_X_FORWARDED_FOR", "REMOTE_ADDR", "REMOTE_HOST"]
|
||||
|
||||
result = get_first_remote_host_from_headers(request, headers)
|
||||
|
||||
# Should have first entry from X-Forwarded-For plus the single values from other headers
|
||||
assert result == {"client.example.com", "172.16.0.1", "proxy2.example.com"}
|
||||
# Should NOT have subsequent entries from X-Forwarded-For
|
||||
assert "proxy1.example.com" not in result
|
||||
|
||||
def test_empty_header_value(self):
|
||||
"""Test handling of empty header values."""
|
||||
request = self._make_mock_request(
|
||||
{
|
||||
"HTTP_X_FORWARDED_FOR": "",
|
||||
"REMOTE_ADDR": "192.168.1.1",
|
||||
}
|
||||
)
|
||||
headers = ["HTTP_X_FORWARDED_FOR", "REMOTE_ADDR"]
|
||||
|
||||
result = get_first_remote_host_from_headers(request, headers)
|
||||
|
||||
assert result == {"192.168.1.1"}
|
||||
|
||||
def test_missing_header(self):
|
||||
"""Test handling of headers that don't exist in environ."""
|
||||
request = self._make_mock_request(
|
||||
{
|
||||
"REMOTE_ADDR": "192.168.1.1",
|
||||
}
|
||||
)
|
||||
headers = ["HTTP_X_FORWARDED_FOR", "REMOTE_ADDR", "REMOTE_HOST"]
|
||||
|
||||
result = get_first_remote_host_from_headers(request, headers)
|
||||
|
||||
assert result == {"192.168.1.1"}
|
||||
|
||||
def test_empty_headers_list(self):
|
||||
"""Test with no headers specified."""
|
||||
request = self._make_mock_request(
|
||||
{
|
||||
"REMOTE_ADDR": "192.168.1.1",
|
||||
}
|
||||
)
|
||||
headers = []
|
||||
|
||||
result = get_first_remote_host_from_headers(request, headers)
|
||||
|
||||
assert result == set()
|
||||
|
||||
def test_whitespace_only_first_entry(self):
|
||||
"""Test handling when first entry is whitespace only."""
|
||||
request = self._make_mock_request(
|
||||
{
|
||||
"HTTP_X_FORWARDED_FOR": " , 192.168.1.1",
|
||||
}
|
||||
)
|
||||
headers = ["HTTP_X_FORWARDED_FOR"]
|
||||
|
||||
result = get_first_remote_host_from_headers(request, headers)
|
||||
|
||||
# Empty/whitespace first entry should be skipped
|
||||
assert result == set()
|
||||
|
||||
def test_single_entry_with_trailing_comma(self):
|
||||
"""Test single entry that happens to have a trailing comma."""
|
||||
request = self._make_mock_request(
|
||||
{
|
||||
"HTTP_X_FORWARDED_FOR": "10.0.0.1,",
|
||||
}
|
||||
)
|
||||
headers = ["HTTP_X_FORWARDED_FOR"]
|
||||
|
||||
result = get_first_remote_host_from_headers(request, headers)
|
||||
|
||||
assert result == {"10.0.0.1"}
|
||||
|
||||
|
||||
class TestIsProxyInHeaders:
|
||||
"""Tests for is_proxy_in_headers function."""
|
||||
|
||||
def _make_mock_request(self, environ):
|
||||
"""Create a mock request with the given environ dict."""
|
||||
request = mock.MagicMock()
|
||||
request.environ = environ
|
||||
return request
|
||||
|
||||
def test_proxy_found_in_single_value(self):
|
||||
"""Test proxy detection in single-value header."""
|
||||
request = self._make_mock_request(
|
||||
{
|
||||
"REMOTE_ADDR": "192.168.1.1",
|
||||
}
|
||||
)
|
||||
|
||||
result = is_proxy_in_headers(request, ["192.168.1.1"], ["REMOTE_ADDR"])
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_proxy_found_in_comma_separated(self):
|
||||
"""Test proxy detection in comma-separated header value."""
|
||||
request = self._make_mock_request(
|
||||
{
|
||||
"HTTP_X_FORWARDED_FOR": "10.0.0.1, 192.168.1.1, 172.16.0.1",
|
||||
}
|
||||
)
|
||||
|
||||
result = is_proxy_in_headers(request, ["192.168.1.1"], ["HTTP_X_FORWARDED_FOR"])
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_proxy_not_found(self):
|
||||
"""Test when proxy is not in any header."""
|
||||
request = self._make_mock_request(
|
||||
{
|
||||
"REMOTE_ADDR": "10.0.0.1",
|
||||
}
|
||||
)
|
||||
|
||||
result = is_proxy_in_headers(request, ["192.168.1.1"], ["REMOTE_ADDR"])
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_multiple_proxies_one_match(self):
|
||||
"""Test with multiple allowed proxies, one matches."""
|
||||
request = self._make_mock_request(
|
||||
{
|
||||
"REMOTE_HOST": "proxy.example.com",
|
||||
}
|
||||
)
|
||||
|
||||
result = is_proxy_in_headers(
|
||||
request,
|
||||
["proxy1.example.com", "proxy.example.com", "proxy2.example.com"],
|
||||
["REMOTE_HOST"],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -7,7 +7,7 @@ from django.utils.timezone import now
|
||||
from awx.main.models.schedules import _fast_forward_rrule, Schedule
|
||||
from dateutil.rrule import HOURLY, MINUTELY, MONTHLY
|
||||
|
||||
REF_DT = datetime.datetime(2026, 4, 16, tzinfo=datetime.timezone.utc)
|
||||
REF_DT = datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -20,10 +20,6 @@ REF_DT = datetime.datetime(2026, 4, 16, tzinfo=datetime.timezone.utc)
|
||||
'DTSTART;TZID=America/New_York:20201118T200000 RRULE:FREQ=MINUTELY;INTERVAL=5;WKST=SU;BYMONTH=2,3;BYMONTHDAY=18;BYHOUR=5;BYMINUTE=35;BYSECOND=0',
|
||||
id='every-5-minutes-at-5:35:00-am-on-the-18th-day-of-feb-or-march-with-week-starting-on-sundays',
|
||||
),
|
||||
pytest.param(
|
||||
'DTSTART;TZID=America/New_York:20251211T130000 RRULE:FREQ=HOURLY;INTERVAL=4;WKST=MO;BYDAY=MO,TU,WE,TH,FR;BYHOUR=1,5,9,13,17,21;BYMINUTE=0',
|
||||
id='every-4-hours-at-1-5-9-13-17-21-am-on-monday-through-friday-with-week-starting-on-monday',
|
||||
),
|
||||
pytest.param(
|
||||
'DTSTART;TZID=America/New_York:20201118T200000 RRULE:FREQ=HOURLY;INTERVAL=5;WKST=SU;BYMONTH=2,3;BYHOUR=5',
|
||||
id='every-5-hours-at-5-am-in-feb-or-march-with-week-starting-on-sundays',
|
||||
@@ -52,7 +48,6 @@ def test_fast_forwarded_rrule_matches_original_occurrence(rrulestr):
|
||||
[
|
||||
pytest.param(datetime.datetime(2024, 12, 1, 0, 0, tzinfo=datetime.timezone.utc), id='ref-dt-out-of-dst'),
|
||||
pytest.param(datetime.datetime(2024, 6, 1, 0, 0, tzinfo=datetime.timezone.utc), id='ref-dt-in-dst'),
|
||||
pytest.param(datetime.datetime(2024, 11, 3, 6, 30, tzinfo=datetime.timezone.utc), id='ref-dt-fall-back-day'),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
@@ -63,8 +58,6 @@ def test_fast_forwarded_rrule_matches_original_occurrence(rrulestr):
|
||||
pytest.param(
|
||||
'DTSTART;TZID=Europe/Lisbon:20230703T005800 RRULE:INTERVAL=10;FREQ=MINUTELY;BYHOUR=9,10,11,12,13,14,15,16,17,18,19,20,21', id='rrule-in-dst-by-hour'
|
||||
),
|
||||
pytest.param('DTSTART;TZID=America/New_York:20230313T005800 RRULE:FREQ=MINUTELY;INTERVAL=7', id='rrule-post-dst-7min'),
|
||||
pytest.param('DTSTART;TZID=America/New_York:20230313T005800 RRULE:FREQ=MINUTELY;INTERVAL=13', id='rrule-post-dst-13min'),
|
||||
],
|
||||
)
|
||||
def test_fast_forward_across_dst(rrulestr, ref_dt):
|
||||
|
||||
@@ -48,16 +48,15 @@ def could_be_playbook(project_path, dir_path, filename):
|
||||
# show up.
|
||||
matched = False
|
||||
try:
|
||||
with codecs.open(playbook_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
for n, line in enumerate(f):
|
||||
if valid_playbook_re.match(line):
|
||||
matched = True
|
||||
break
|
||||
# Any YAML file can also be encrypted with vault;
|
||||
# allow these to be used as the main playbook.
|
||||
elif n == 0 and line.startswith('$ANSIBLE_VAULT;'):
|
||||
matched = True
|
||||
break
|
||||
for n, line in enumerate(codecs.open(playbook_path, 'r', encoding='utf-8', errors='ignore')):
|
||||
if valid_playbook_re.match(line):
|
||||
matched = True
|
||||
break
|
||||
# Any YAML file can also be encrypted with vault;
|
||||
# allow these to be used as the main playbook.
|
||||
elif n == 0 and line.startswith('$ANSIBLE_VAULT;'):
|
||||
matched = True
|
||||
break
|
||||
except IOError:
|
||||
return None
|
||||
if not matched:
|
||||
|
||||
@@ -55,8 +55,6 @@ def construct_rsyslog_conf_template(settings=settings):
|
||||
)
|
||||
|
||||
def escape_quotes(x):
|
||||
if x is None:
|
||||
return ''
|
||||
return x.replace('"', '\\"')
|
||||
|
||||
if not enabled:
|
||||
|
||||
@@ -45,3 +45,38 @@ def delete_headers_starting_with_http(request: Request, headers: list[str]):
|
||||
for header in headers:
|
||||
if header.startswith('HTTP_'):
|
||||
request.environ.pop(header, None)
|
||||
|
||||
|
||||
def get_first_remote_host_from_headers(request: Request, headers: list[str]) -> set[str]:
|
||||
"""
|
||||
Extract remote host addresses from headers, considering only the first entry
|
||||
in comma-separated values.
|
||||
|
||||
For headers like X-Forwarded-For that may contain multiple IPs (e.g., "client, proxy1, proxy2"),
|
||||
only the first entry (the original client) is considered.
|
||||
|
||||
Example:
|
||||
request.environ = {
|
||||
"HTTP_X_FORWARDED_FOR": "10.0.0.1, 192.168.1.1, 172.16.0.1",
|
||||
"REMOTE_ADDR": "192.168.1.1",
|
||||
"REMOTE_HOST": "proxy.example.com"
|
||||
}
|
||||
headers = ["HTTP_X_FORWARDED_FOR", "REMOTE_ADDR", "REMOTE_HOST"]
|
||||
|
||||
Returns: {"10.0.0.1", "192.168.1.1", "proxy.example.com"}
|
||||
(Only the first IP "10.0.0.1" from X-Forwarded-For, not the full chain)
|
||||
|
||||
request: The DRF/Django request. request.environ dict will be used for extracting hosts
|
||||
headers: A list of header keys to check for remote host values
|
||||
"""
|
||||
remote_hosts = set()
|
||||
|
||||
for header in headers:
|
||||
header_value = request.environ.get(header, '')
|
||||
if header_value:
|
||||
# Only take the first entry if comma-separated
|
||||
first_value = header_value.split(',')[0].strip()
|
||||
if first_value:
|
||||
remote_hosts.add(first_value)
|
||||
|
||||
return remote_hosts
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
from ansible_base.resource_registry.workload_identity_client import get_workload_identity_client
|
||||
|
||||
__all__ = ['retrieve_workload_identity_jwt_with_claims']
|
||||
|
||||
|
||||
def retrieve_workload_identity_jwt_with_claims(
|
||||
claims: dict,
|
||||
audience: str,
|
||||
scope: str,
|
||||
workload_ttl_seconds: int | None = None,
|
||||
) -> str:
|
||||
"""Retrieve JWT token from workload claims.
|
||||
Raises:
|
||||
RuntimeError: if the workload identity client is not configured.
|
||||
"""
|
||||
client = get_workload_identity_client()
|
||||
if client is None:
|
||||
raise RuntimeError("Workload identity client is not configured")
|
||||
kwargs = {"claims": claims, "scope": scope, "audience": audience}
|
||||
if workload_ttl_seconds:
|
||||
kwargs["workload_ttl_seconds"] = workload_ttl_seconds
|
||||
return client.request_workload_jwt(**kwargs).jwt
|
||||
@@ -139,7 +139,7 @@ class WebsocketRelayConnection:
|
||||
except json.JSONDecodeError:
|
||||
logmsg = "Failed to decode message from web node"
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logmsg = "{} {}".format(logmsg, msg.data)
|
||||
logmsg = "{} {}".format(logmsg, payload)
|
||||
logger.warning(logmsg)
|
||||
continue
|
||||
|
||||
@@ -242,7 +242,7 @@ class WebSocketRelayManager(object):
|
||||
except json.JSONDecodeError:
|
||||
logmsg = "Failed to decode message from pg_notify channel `web_ws_heartbeat`"
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logmsg = "{} {}".format(logmsg, notif.payload)
|
||||
logmsg = "{} {}".format(logmsg, payload)
|
||||
logger.warning(logmsg)
|
||||
continue
|
||||
|
||||
|
||||
@@ -21,11 +21,8 @@ DOCUMENTATION = '''
|
||||
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
from importlib.resources import files
|
||||
|
||||
from packaging.version import Version, InvalidVersion
|
||||
|
||||
from ansible.plugins.callback import CallbackBase
|
||||
|
||||
# NOTE: in Ansible 1.2 or later general logging is available without
|
||||
@@ -44,101 +41,6 @@ from ansible.galaxy.collection import find_existing_collections
|
||||
from ansible.utils.collection_loader import AnsibleCollectionConfig
|
||||
import ansible.constants as C
|
||||
|
||||
# External query path constants
|
||||
EXTERNAL_QUERY_COLLECTION = 'ansible_collections.redhat.indirect_accounting'
|
||||
|
||||
|
||||
def _get_query_file_dir():
|
||||
"""Return the query file directory or None."""
|
||||
try:
|
||||
queries_dir = files(EXTERNAL_QUERY_COLLECTION) / 'extensions' / 'audit' / 'external_queries'
|
||||
except ModuleNotFoundError:
|
||||
return None
|
||||
if not queries_dir.is_dir():
|
||||
return None
|
||||
return queries_dir
|
||||
|
||||
|
||||
def list_external_queries(namespace, name):
|
||||
"""List all available external query versions for a collection.
|
||||
|
||||
Args:
|
||||
namespace: Collection namespace (e.g., 'community')
|
||||
name: Collection name (e.g., 'vmware')
|
||||
|
||||
Returns:
|
||||
List of Version objects for all available query files
|
||||
matching the namespace.name pattern.
|
||||
"""
|
||||
versions = []
|
||||
|
||||
if not (queries_dir := _get_query_file_dir()):
|
||||
return versions
|
||||
|
||||
# Pattern: namespace.name.X.Y.Z.yml where X.Y.Z is the version
|
||||
pattern = re.compile(rf'^{re.escape(namespace)}\.{re.escape(name)}\.(.+)\.yml$')
|
||||
|
||||
for query_file in queries_dir.iterdir():
|
||||
match = pattern.match(query_file.name)
|
||||
if match:
|
||||
version_str = match.group(1)
|
||||
try:
|
||||
versions.append(Version(version_str))
|
||||
except InvalidVersion:
|
||||
# Skip files with invalid version strings
|
||||
pass
|
||||
|
||||
return versions
|
||||
|
||||
|
||||
def find_external_query_with_fallback(namespace, name, installed_version):
|
||||
"""Find external query file with semantic version fallback.
|
||||
|
||||
Args:
|
||||
namespace: Collection namespace (e.g., 'community')
|
||||
name: Collection name (e.g., 'vmware')
|
||||
installed_version: Version string of installed collection (e.g., '4.5.0')
|
||||
|
||||
Returns:
|
||||
Tuple of (query_content, fallback_used, fallback_version) or (None, False, None)
|
||||
- query_content: The query file content if found
|
||||
- fallback_used: True if a fallback version was used instead of exact match
|
||||
- fallback_version: The version string used (for logging)
|
||||
"""
|
||||
if not (queries_dir := _get_query_file_dir()):
|
||||
return None, False, None
|
||||
|
||||
# 1. Try exact version match first
|
||||
exact_file = queries_dir / f'{namespace}.{name}.{installed_version}.yml'
|
||||
if exact_file.exists():
|
||||
with exact_file.open('r') as f:
|
||||
return f.read(), False, installed_version
|
||||
|
||||
# 2. Find compatible fallback (same major version, nearest lower version)
|
||||
try:
|
||||
installed_version_object = Version(installed_version)
|
||||
except InvalidVersion:
|
||||
# Can't do version comparison for fallback
|
||||
return None, False, None
|
||||
available_versions = list_external_queries(namespace, name)
|
||||
if not available_versions:
|
||||
return None, False, None
|
||||
|
||||
# Filter to same major version and versions <= installed version
|
||||
compatible_versions = [v for v in available_versions if v.major == installed_version_object.major and v <= installed_version_object]
|
||||
if not compatible_versions:
|
||||
return None, False, None
|
||||
|
||||
# Select nearest lower version - highest compatible version
|
||||
fallback_version_object = max(compatible_versions)
|
||||
fallback_version_str = str(fallback_version_object)
|
||||
fallback_file = queries_dir / f'{namespace}.{name}.{fallback_version_str}.yml'
|
||||
if fallback_file.exists():
|
||||
with fallback_file.open('r') as f:
|
||||
return f.read(), True, fallback_version_str
|
||||
|
||||
return None, False, None
|
||||
|
||||
|
||||
@with_collection_artifacts_manager
|
||||
def list_collections(artifacts_manager=None):
|
||||
@@ -175,22 +77,10 @@ class CallbackModule(CallbackBase):
|
||||
'version': candidate.ver,
|
||||
}
|
||||
|
||||
# 1. Check for embedded query file (takes precedence)
|
||||
embedded_query_file = files(f'ansible_collections.{candidate.namespace}.{candidate.name}') / 'extensions' / 'audit' / 'event_query.yml'
|
||||
if embedded_query_file.exists():
|
||||
with embedded_query_file.open('r') as f:
|
||||
query_file = files(f'ansible_collections.{candidate.namespace}.{candidate.name}') / 'extensions' / 'audit' / 'event_query.yml'
|
||||
if query_file.exists():
|
||||
with query_file.open('r') as f:
|
||||
collection_print['host_query'] = f.read()
|
||||
self._display.vv(f"Using embedded query for {candidate.fqcn} v{candidate.ver}")
|
||||
else:
|
||||
# 2. Check for external query file with version fallback
|
||||
query_content, fallback_used, version_used = find_external_query_with_fallback(candidate.namespace, candidate.name, candidate.ver)
|
||||
if query_content:
|
||||
collection_print['host_query'] = query_content
|
||||
if fallback_used:
|
||||
# AC5.6: Log when fallback is used
|
||||
self._display.v(f"Using external query {version_used} for {candidate.fqcn} v{candidate.ver}.")
|
||||
else:
|
||||
self._display.v(f"Using external query for {candidate.fqcn} v{candidate.ver}")
|
||||
|
||||
collections_print[candidate.fqcn] = collection_print
|
||||
|
||||
|
||||
@@ -236,7 +236,7 @@
|
||||
changed_when: "'was installed successfully' in galaxy_result.stdout"
|
||||
when:
|
||||
- roles_enabled | bool
|
||||
- req_file | length > 0
|
||||
- req_file
|
||||
tags:
|
||||
- install_roles
|
||||
|
||||
@@ -255,7 +255,7 @@
|
||||
when:
|
||||
- "ansible_version.full is version_compare('2.9', '>=')"
|
||||
- collections_enabled | bool
|
||||
- req_file | length > 0
|
||||
- req_file
|
||||
tags:
|
||||
- install_collections
|
||||
|
||||
@@ -276,7 +276,7 @@
|
||||
- "ansible_version.full is version_compare('2.10', '>=')"
|
||||
- collections_enabled | bool
|
||||
- roles_enabled | bool
|
||||
- req_file | length > 0
|
||||
- req_file
|
||||
tags:
|
||||
- install_collections
|
||||
- install_roles
|
||||
|
||||
@@ -63,15 +63,6 @@ assert_production_settings(DYNACONF, settings_dir, settings_file_path)
|
||||
# Load envvars at the end to allow them to override everything loaded so far
|
||||
load_envvars(DYNACONF)
|
||||
|
||||
# When deployed as part of AAP (RESOURCE_SERVER__URL is set), enforce JWT-only
|
||||
# authentication. This ensures all requests go through the gateway and prevents
|
||||
# direct API access to Controller bypassing the platform's authentication.
|
||||
if DYNACONF.get('RESOURCE_SERVER__URL', None):
|
||||
DYNACONF.set(
|
||||
"REST_FRAMEWORK__DEFAULT_AUTHENTICATION_CLASSES",
|
||||
['ansible_base.jwt_consumer.awx.auth.AwxJWTAuthentication'],
|
||||
)
|
||||
|
||||
# This must run after all custom settings are loaded
|
||||
DYNACONF.update(
|
||||
merge_application_name(DYNACONF),
|
||||
|
||||
@@ -774,7 +774,7 @@ LOGGING = {
|
||||
'awx.conf.settings': {'handlers': ['null'], 'level': 'WARNING'},
|
||||
'awx.main': {'handlers': ['null']},
|
||||
'awx.main.commands.run_callback_receiver': {'handlers': ['callback_receiver'], 'level': 'INFO'}, # very noisey debug-level logs
|
||||
'awx.main.dispatch': {'handlers': ['task_system']},
|
||||
'awx.main.dispatch': {'handlers': ['dispatcher']},
|
||||
'awx.main.consumers': {'handlers': ['console', 'file', 'tower_warnings'], 'level': 'INFO'},
|
||||
'awx.main.rsyslog_configurer': {'handlers': ['rsyslog_configurer']},
|
||||
'awx.main.cache_clear': {'handlers': ['cache_clear']},
|
||||
@@ -1134,7 +1134,6 @@ OPA_REQUEST_RETRIES = 2 # The number of retry attempts for connecting to the OP
|
||||
|
||||
# feature flags
|
||||
FEATURE_INDIRECT_NODE_COUNTING_ENABLED = False
|
||||
FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED = False
|
||||
|
||||
# Dispatcher worker lifetime. If set to None, workers will never be retired
|
||||
# based on age. Note workers will finish their last task before retiring if
|
||||
|
||||
@@ -19,9 +19,6 @@ SECRET_KEY = None
|
||||
# See https://docs.djangoproject.com/en/dev/ref/settings/#allowed-hosts
|
||||
ALLOWED_HOSTS = []
|
||||
|
||||
# In production, trust the X-Forwarded-For header set by the reverse proxy
|
||||
REMOTE_HOST_HEADERS = ['HTTP_X_FORWARDED_FOR']
|
||||
|
||||
# Ansible base virtualenv paths and enablement
|
||||
# only used for deprecated fields and management commands for them
|
||||
BASE_VENV_PATH = os.path.realpath("/var/lib/awx/venv")
|
||||
|
||||
@@ -34,6 +34,9 @@ def get_urlpatterns(prefix=None):
|
||||
re_path(r'^(?:api/)?500.html$', handle_500),
|
||||
re_path(r'^csp-violation/', handle_csp_violation),
|
||||
re_path(r'^login/', handle_login_redirect),
|
||||
# want api/v2/doesnotexist to return a 404, not match the ui urls,
|
||||
# so use a negative lookahead assertion here
|
||||
re_path(r'^(?!api/).*', include('awx.ui.urls', namespace='ui')),
|
||||
]
|
||||
|
||||
if settings.DYNACONF.is_development_mode:
|
||||
@@ -44,12 +47,6 @@ def get_urlpatterns(prefix=None):
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# want api/v2/doesnotexist to return a 404, not match the ui urls,
|
||||
# so use a negative lookahead assertion in the pattern below
|
||||
urlpatterns += [
|
||||
re_path(r'^(?!api/).*', include('awx.ui.urls', namespace='ui')),
|
||||
]
|
||||
|
||||
return urlpatterns
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# This is a cross-platform list tracking distribution packages needed by tests;
|
||||
# see https://docs.openstack.org/infra/bindep/ for additional information.
|
||||
|
||||
python3-pytz [platform:centos-8 platform:rhel-8 platform:centos-9 platform:rhel-9]
|
||||
python38-pytz [platform:centos-8 platform:rhel-8]
|
||||
|
||||
# awxkit
|
||||
python3-requests [platform:centos-8 platform:rhel-8 platform:centos-9 platform:rhel-9]
|
||||
python3-pyyaml [platform:centos-8 platform:rhel-8 platform:centos-9 platform:rhel-9]
|
||||
python38-requests [platform:centos-8 platform:rhel-8]
|
||||
python38-pyyaml [platform:centos-8 platform:rhel-8]
|
||||
|
||||
@@ -55,20 +55,6 @@ options:
|
||||
- Defaults to 10s, but this is handled by the shared module_utils code
|
||||
type: float
|
||||
aliases: [ aap_request_timeout ]
|
||||
max_retries:
|
||||
description:
|
||||
- Specify the max retries to be used with some connection issues.
|
||||
- Defaults to 5.
|
||||
- If value not set, will try environment variable C(AAP_MAX_RETRIES) and then config files.
|
||||
type: int
|
||||
aliases: [ aap_max_retries ]
|
||||
retry_backoff_factor:
|
||||
description:
|
||||
- Backoff factor used when retrying connections.
|
||||
- Defaults to 2.
|
||||
- If value not set, will try environment variable C(AAP_RETRY_BACKOFF_FACTOR) and then config files.
|
||||
type: int
|
||||
aliases: [ aap_retry_backoff_factor ]
|
||||
controller_config_file:
|
||||
description:
|
||||
- Path to the controller config file.
|
||||
|
||||
@@ -76,24 +76,6 @@ options:
|
||||
why: Support for AAP variables
|
||||
alternatives: 'AAP_REQUEST_TIMEOUT'
|
||||
aliases: [ aap_request_timeout ]
|
||||
max_retries:
|
||||
description:
|
||||
- Specify the max retries to be used with some connection issues.
|
||||
- Defaults to 5.
|
||||
- This will not work with the export or import modules.
|
||||
type: int
|
||||
env:
|
||||
- name: AAP_MAX_RETRIES
|
||||
aliases: [ aap_max_retries ]
|
||||
retry_backoff_factor:
|
||||
description:
|
||||
- Backoff factor used when retrying connections.
|
||||
- Defaults to 2.
|
||||
- This will not work with the export or import modules.
|
||||
type: int
|
||||
env:
|
||||
- name: AAP_RETRY_BACKOFF_FACTOR
|
||||
aliases: [ aap_retry_backoff_factor ]
|
||||
notes:
|
||||
- If no I(config_file) is provided we will attempt to use the tower-cli library
|
||||
defaults to find your host information.
|
||||
|
||||
@@ -15,7 +15,6 @@ from ansible.module_utils.six.moves.configparser import ConfigParser, NoOptionEr
|
||||
from base64 import b64encode
|
||||
from socket import getaddrinfo, IPPROTO_TCP
|
||||
import time
|
||||
import random
|
||||
from json import loads, dumps
|
||||
from os.path import isfile, expanduser, split, join, exists, isdir
|
||||
from os import access, R_OK, getcwd, environ, getenv
|
||||
@@ -38,19 +37,6 @@ except ImportError:
|
||||
|
||||
CONTROLLER_BASE_PATH_ENV_VAR = "CONTROLLER_OPTIONAL_API_URLPATTERN_PREFIX"
|
||||
|
||||
# 502/503: request never reached the server — always safe to retry any method
|
||||
ALWAYS_RETRYABLE = {
|
||||
502: ['GET', 'POST', 'PATCH', 'DELETE'], # Bad Gateway
|
||||
503: ['GET', 'POST', 'PATCH', 'DELETE'], # Service Unavailable
|
||||
}
|
||||
|
||||
# 500/504: idempotent methods only — GETs are reads, PATCH/DELETE are
|
||||
# idempotent by definition; POST is excluded unless we know it's safe.
|
||||
IDEMPOTENT_RETRYABLE = {
|
||||
500: ['GET', 'PATCH', 'DELETE'], # Internal Server Error
|
||||
504: ['GET', 'PATCH', 'DELETE'], # Gateway Timeout
|
||||
}
|
||||
|
||||
|
||||
class ConfigFileException(Exception):
|
||||
pass
|
||||
@@ -86,16 +72,6 @@ class ControllerModule(AnsibleModule):
|
||||
aliases=['aap_request_timeout'],
|
||||
required=False,
|
||||
fallback=(env_fallback, ['CONTROLLER_REQUEST_TIMEOUT', 'AAP_REQUEST_TIMEOUT'])),
|
||||
max_retries=dict(
|
||||
type='int',
|
||||
aliases=['aap_max_retries'],
|
||||
required=False,
|
||||
fallback=(env_fallback, ['AAP_MAX_RETRIES'])),
|
||||
retry_backoff_factor=dict(
|
||||
type='int',
|
||||
aliases=['aap_retry_backoff_factor'],
|
||||
required=False,
|
||||
fallback=(env_fallback, ['AAP_RETRY_BACKOFF_FACTOR'])),
|
||||
aap_token=dict(
|
||||
type='raw',
|
||||
no_log=True,
|
||||
@@ -116,16 +92,12 @@ class ControllerModule(AnsibleModule):
|
||||
'password': 'controller_password',
|
||||
'verify_ssl': 'validate_certs',
|
||||
'request_timeout': 'request_timeout',
|
||||
'max_retries': 'max_retries',
|
||||
'retry_backoff_factor': 'retry_backoff_factor',
|
||||
}
|
||||
host = '127.0.0.1'
|
||||
username = None
|
||||
password = None
|
||||
verify_ssl = True
|
||||
request_timeout = 10
|
||||
max_retries = 5
|
||||
retry_backoff_factor = 2
|
||||
authenticated = False
|
||||
config_name = 'tower_cli.cfg'
|
||||
version_checked = False
|
||||
@@ -516,49 +488,6 @@ class ControllerAPIModule(ControllerModule):
|
||||
def resolve_name_to_id(self, endpoint, name_or_id):
|
||||
return self.get_exactly_one(endpoint, name_or_id)['id']
|
||||
|
||||
def is_retryable(self, status_code, method, endpoint):
|
||||
"""
|
||||
Determine whether a failed request is safe to retry.
|
||||
|
||||
Args:
|
||||
status_code (int): HTTP status code returned by the server.
|
||||
method (str): HTTP verb in uppercase ('GET', 'POST', etc.).
|
||||
endpoint (str): The API endpoint path (e.g. '/api/v2/job_templates/1/launch/').
|
||||
|
||||
Returns:
|
||||
bool: True if the request can safely be retried.
|
||||
"""
|
||||
# --- Always safe: 502/503 mean the request never reached AWX ---
|
||||
if method in ALWAYS_RETRYABLE.get(status_code, []):
|
||||
return True
|
||||
|
||||
# --- Safe for inherently idempotent methods (GET, PATCH, DELETE) ---
|
||||
if method in IDEMPOTENT_RETRYABLE.get(status_code, []):
|
||||
return True
|
||||
|
||||
# --- POST/PATCH on 500/504: safe UNLESS the endpoint triggers execution ---
|
||||
if method in ('POST', 'PATCH') and status_code in (500, 504):
|
||||
|
||||
# /launch, /relaunch, /callback etc. — retrying would double-execute
|
||||
# Catches: /job_templates/1/launch/, /workflow_job_templates/1/launch/,
|
||||
# /jobs/1/relaunch/, /ad_hoc_commands/1/relaunch/ …
|
||||
launch_keywords = ('/launch', '/relaunch', '/callback')
|
||||
if any(kw in endpoint for kw in launch_keywords):
|
||||
return False
|
||||
|
||||
# POST to the ad_hoc_commands collection root creates AND immediately
|
||||
# executes the command — not safe to retry.
|
||||
# PATCH to /ad_hoc_commands/<id>/ is fine (handled by PATCH branch above
|
||||
# but would also pass through here correctly).
|
||||
if method == 'POST' and endpoint.rstrip('/').endswith('/ad_hoc_commands'):
|
||||
return False
|
||||
|
||||
# All other POST/PATCH endpoints (create resource, update resource) are
|
||||
# safe: a 500/504 before the DB transaction commits means no side-effect.
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def make_request(self, method, endpoint, *args, **kwargs):
|
||||
# In case someone is calling us directly; make sure we were given a method, let's not just assume a GET
|
||||
if not method:
|
||||
@@ -583,155 +512,121 @@ class ControllerAPIModule(ControllerModule):
|
||||
headers.setdefault('Content-Type', 'application/json')
|
||||
kwargs['headers'] = headers
|
||||
|
||||
data = None
|
||||
data = None # Important, if content type is not JSON, this should not be dict type
|
||||
if headers.get('Content-Type', '') == 'application/json':
|
||||
data = dumps(kwargs.get('data', {}))
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Retry loop — wraps only the session.open() + HTTPError handling
|
||||
# Everything above (auth, URL building) happens once before the loop
|
||||
# ----------------------------------------------------------------
|
||||
max_retries = self.max_retries
|
||||
backoff_factor = self.retry_backoff_factor
|
||||
last_response = None
|
||||
|
||||
for attempt in range(max_retries + 1): # attempt 0 = first try
|
||||
|
||||
if attempt > 0:
|
||||
sleep_time = (backoff_factor ** (attempt - 1)) * (0.5 + random.random())
|
||||
self.warn(
|
||||
'Retrying {0} {1} (attempt {2}/{3}) after {4}s due to status {5}'.format(
|
||||
method, url.path, attempt, max_retries, sleep_time,
|
||||
last_response if last_response else 'connection error'
|
||||
)
|
||||
)
|
||||
time.sleep(sleep_time)
|
||||
|
||||
try:
|
||||
response = self.session.open(
|
||||
method, url.geturl(),
|
||||
headers=headers,
|
||||
timeout=self.request_timeout,
|
||||
validate_certs=self.verify_ssl,
|
||||
follow_redirects=True,
|
||||
data=data
|
||||
)
|
||||
|
||||
except (SSLValidationError) as ssl_err:
|
||||
# SSL errors are never retryable — cert problems won't fix themselves
|
||||
self.fail_json(msg="Could not establish a secure connection to your host ({0}): {1}.".format(url.netloc, ssl_err))
|
||||
|
||||
except (ConnectionError) as con_err:
|
||||
# Connection errors may be transient — retry if we have attempts left
|
||||
last_response = 'ConnectionError'
|
||||
if attempt < max_retries:
|
||||
continue
|
||||
self.fail_json(msg="There was a network error of some kind trying to connect to your host ({0}): {1}.".format(url.netloc, con_err))
|
||||
|
||||
except (HTTPError) as he:
|
||||
# ---- Retryable HTTP errors ----
|
||||
if self.is_retryable(he.code, method, url.path):
|
||||
# Exhausted retries on a retryable error go on to regular failure checks.
|
||||
if attempt < max_retries:
|
||||
continue
|
||||
# Exhausted retries - provide informative message
|
||||
self.fail_json(
|
||||
msg="Request to {0} failed with status {1} after {2} retries. "
|
||||
"This may indicate the server is overloaded.".format(url.path, he.code, max_retries)
|
||||
)
|
||||
# ---- Non-retryable HTTP errors (existing behaviour preserved) ----
|
||||
if he.code >= 500:
|
||||
self.fail_json(msg='The host sent back a server error ({1}): {0}. Please check the logs and try again later'.format(url.path, he))
|
||||
elif he.code == 401:
|
||||
self.fail_json(msg='Invalid authentication credentials for {0} (HTTP 401).'.format(url.path))
|
||||
elif he.code == 403:
|
||||
body = he.read()
|
||||
raw = body.decode('utf-8') if isinstance(body, bytes) else str(body)
|
||||
if 'unable to connect to database' in raw.lower():
|
||||
if attempt < max_retries:
|
||||
continue
|
||||
self.fail_json(
|
||||
msg="Request to {0} failed with status 403 (database unavailable) after {1} retries.".format(url.path, max_retries),
|
||||
)
|
||||
# Reuse raw instead of reading again
|
||||
try:
|
||||
err_msg = loads(raw)
|
||||
err_msg = err_msg['detail']
|
||||
except (ValueError, KeyError):
|
||||
err_msg = raw
|
||||
prepend_msg = " Use the collection ansible.platform to modify resources Organization, User, or Team." if (
|
||||
"this resource via the platform ingress") in err_msg else ""
|
||||
self.fail_json(msg="You don't have permission to {1} to {0} (HTTP 403).{2}".format(url.path, method, prepend_msg))
|
||||
elif he.code == 404:
|
||||
if kwargs.get('return_none_on_404', False):
|
||||
return None
|
||||
self.fail_json(msg='The requested object could not be found at {0}.'.format(url.path))
|
||||
elif he.code == 405:
|
||||
self.fail_json(msg="Cannot make a request with the {0} method to this endpoint {1}".format(method, url.path))
|
||||
elif he.code >= 400:
|
||||
page_data = he.read()
|
||||
try:
|
||||
return {'status_code': he.code, 'json': loads(page_data)}
|
||||
except ValueError:
|
||||
return {'status_code': he.code, 'text': page_data}
|
||||
else:
|
||||
self.fail_json(msg="Unexpected return code when calling {0}: {1}".format(url.geturl(), he))
|
||||
|
||||
except (Exception) as e:
|
||||
self.fail_json(msg="There was an unknown error when trying to connect to {2}: {0} {1}".format(type(e).__name__, e, url.geturl()))
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Successful response — fall through from session.open()
|
||||
# The version check and response parsing happen once on success
|
||||
# ----------------------------------------------------------------
|
||||
if not self.version_checked:
|
||||
try:
|
||||
response = self.session.open(
|
||||
method, url.geturl(),
|
||||
headers=headers,
|
||||
timeout=self.request_timeout,
|
||||
validate_certs=self.verify_ssl,
|
||||
follow_redirects=True,
|
||||
data=data
|
||||
)
|
||||
except (SSLValidationError) as ssl_err:
|
||||
self.fail_json(msg="Could not establish a secure connection to your host ({1}): {0}.".format(url.netloc, ssl_err))
|
||||
except (ConnectionError) as con_err:
|
||||
self.fail_json(msg="There was a network error of some kind trying to connect to your host ({1}): {0}.".format(url.netloc, con_err))
|
||||
except (HTTPError) as he:
|
||||
# Sanity check: Did the server send back some kind of internal error?
|
||||
if he.code >= 500:
|
||||
self.fail_json(msg='The host sent back a server error ({1}): {0}. Please check the logs and try again later'.format(url.path, he))
|
||||
# Sanity check: Did we fail to authenticate properly? If so, fail out now; this is always a failure.
|
||||
elif he.code == 401:
|
||||
self.fail_json(msg='Invalid authentication credentials for {0} (HTTP 401).'.format(url.path))
|
||||
# Sanity check: Did we get a forbidden response, which means that the user isn't allowed to do this? Report that.
|
||||
elif he.code == 403:
|
||||
# Hack: Tell the customer to use the platform supported collection when interacting with Org, Team, User Controller endpoints
|
||||
err_msg = he.fp.read().decode('utf-8')
|
||||
try:
|
||||
controller_type = response.getheader('X-API-Product-Name', None)
|
||||
controller_version = response.getheader('X-API-Product-Version', None)
|
||||
except Exception:
|
||||
controller_type = response.info().getheader('X-API-Product-Name', None)
|
||||
controller_version = response.info().getheader('X-API-Product-Version', None)
|
||||
|
||||
parsed_collection_version = Version(self._COLLECTION_VERSION).version
|
||||
if controller_version:
|
||||
parsed_controller_version = Version(controller_version).version
|
||||
if controller_type == 'AWX':
|
||||
collection_compare_ver = parsed_collection_version[0]
|
||||
controller_compare_ver = parsed_controller_version[0]
|
||||
else:
|
||||
collection_compare_ver = "{0}.{1}".format(parsed_collection_version[0], parsed_collection_version[1])
|
||||
controller_compare_ver = '{0}.{1}'.format(parsed_controller_version[0], parsed_controller_version[1])
|
||||
|
||||
if self._COLLECTION_TYPE not in self.collection_to_version or self.collection_to_version[self._COLLECTION_TYPE] != controller_type:
|
||||
self.warn("You are using the {0} version of this collection but connecting to {1}".format(self._COLLECTION_TYPE, controller_type))
|
||||
elif collection_compare_ver != controller_compare_ver:
|
||||
self.warn(
|
||||
"You are running collection version {0} but connecting to {2} version {1}".format(
|
||||
self._COLLECTION_VERSION, controller_version, controller_type
|
||||
)
|
||||
)
|
||||
|
||||
self.version_checked = True
|
||||
|
||||
response_body = ''
|
||||
try:
|
||||
response_body = response.read()
|
||||
except (Exception) as e:
|
||||
self.fail_json(msg="Failed to read response body: {0}".format(e))
|
||||
|
||||
response_json = {}
|
||||
if response_body and response_body != '':
|
||||
# Defensive coding. Handle json responses and non-json responses
|
||||
err_msg = loads(err_msg)
|
||||
err_msg = err_msg['detail']
|
||||
# JSONDecodeError only available on Python 3.5+
|
||||
except ValueError:
|
||||
pass
|
||||
prepend_msg = " Use the collection ansible.platform to modify resources Organization, User, or Team." if (
|
||||
"this resource via the platform ingress") in err_msg else ""
|
||||
self.fail_json(msg="You don't have permission to {1} to {0} (HTTP 403).{2}".format(url.path, method, prepend_msg))
|
||||
# Sanity check: Did we get a 404 response?
|
||||
# Requests with primary keys will return a 404 if there is no response, and we want to consistently trap these.
|
||||
elif he.code == 404:
|
||||
if kwargs.get('return_none_on_404', False):
|
||||
return None
|
||||
self.fail_json(msg='The requested object could not be found at {0}.'.format(url.path))
|
||||
# Sanity check: Did we get a 405 response?
|
||||
# A 405 means we used a method that isn't allowed. Usually this is a bad request, but it requires special treatment because the
|
||||
# API sends it as a logic error in a few situations (e.g. trying to cancel a job that isn't running).
|
||||
elif he.code == 405:
|
||||
self.fail_json(msg="Cannot make a request with the {0} method to this endpoint {1}".format(method, url.path))
|
||||
# Sanity check: Did we get some other kind of error? If so, write an appropriate error message.
|
||||
elif he.code >= 400:
|
||||
# We are going to return a 400 so the module can decide what to do with it
|
||||
page_data = he.read()
|
||||
try:
|
||||
response_json = loads(response_body)
|
||||
except (Exception) as e:
|
||||
self.fail_json(msg="Failed to parse the response json: {0}".format(e))
|
||||
|
||||
if PY2:
|
||||
status_code = response.getcode()
|
||||
return {'status_code': he.code, 'json': loads(page_data)}
|
||||
# JSONDecodeError only available on Python 3.5+
|
||||
except ValueError:
|
||||
return {'status_code': he.code, 'text': page_data}
|
||||
elif he.code == 204 and method == 'DELETE':
|
||||
# A 204 is a normal response for a delete function
|
||||
pass
|
||||
else:
|
||||
status_code = response.status
|
||||
self.fail_json(msg="Unexpected return code when calling {0}: {1}".format(url.geturl(), he))
|
||||
except (Exception) as e:
|
||||
self.fail_json(msg="There was an unknown error when trying to connect to {2}: {0} {1}".format(type(e).__name__, e, url.geturl()))
|
||||
|
||||
return {'status_code': status_code, 'json': response_json}
|
||||
if not self.version_checked:
|
||||
# In PY2 we get back an HTTPResponse object but PY2 is returning an addinfourl
|
||||
# First try to get the headers in PY3 format and then drop down to PY2.
|
||||
try:
|
||||
controller_type = response.getheader('X-API-Product-Name', None)
|
||||
controller_version = response.getheader('X-API-Product-Version', None)
|
||||
except Exception:
|
||||
controller_type = response.info().getheader('X-API-Product-Name', None)
|
||||
controller_version = response.info().getheader('X-API-Product-Version', None)
|
||||
|
||||
parsed_collection_version = Version(self._COLLECTION_VERSION).version
|
||||
if controller_version:
|
||||
parsed_controller_version = Version(controller_version).version
|
||||
if controller_type == 'AWX':
|
||||
collection_compare_ver = parsed_collection_version[0]
|
||||
controller_compare_ver = parsed_controller_version[0]
|
||||
else:
|
||||
collection_compare_ver = "{0}.{1}".format(parsed_collection_version[0], parsed_collection_version[1])
|
||||
controller_compare_ver = '{0}.{1}'.format(parsed_controller_version[0], parsed_controller_version[1])
|
||||
|
||||
if self._COLLECTION_TYPE not in self.collection_to_version or self.collection_to_version[self._COLLECTION_TYPE] != controller_type:
|
||||
self.warn("You are using the {0} version of this collection but connecting to {1}".format(self._COLLECTION_TYPE, controller_type))
|
||||
elif collection_compare_ver != controller_compare_ver:
|
||||
self.warn(
|
||||
"You are running collection version {0} but connecting to {2} version {1}".format(
|
||||
self._COLLECTION_VERSION, controller_version, controller_type
|
||||
)
|
||||
)
|
||||
|
||||
self.version_checked = True
|
||||
|
||||
response_body = ''
|
||||
try:
|
||||
response_body = response.read()
|
||||
except (Exception) as e:
|
||||
self.fail_json(msg="Failed to read response body: {0}".format(e))
|
||||
|
||||
response_json = {}
|
||||
if response_body and response_body != '':
|
||||
try:
|
||||
response_json = loads(response_body)
|
||||
except (Exception) as e:
|
||||
self.fail_json(msg="Failed to parse the response json: {0}".format(e))
|
||||
|
||||
if PY2:
|
||||
status_code = response.getcode()
|
||||
else:
|
||||
status_code = response.status
|
||||
return {'status_code': status_code, 'json': response_json}
|
||||
|
||||
def api_path(self, app_key=None):
|
||||
|
||||
|
||||
@@ -276,7 +276,6 @@ options:
|
||||
- ''
|
||||
- 'github'
|
||||
- 'gitlab'
|
||||
- 'bitbucket_dc'
|
||||
webhook_credential:
|
||||
description:
|
||||
- Personal Access Token for posting back the status to the service API
|
||||
@@ -437,7 +436,7 @@ def main():
|
||||
scm_branch=dict(),
|
||||
ask_scm_branch_on_launch=dict(type='bool'),
|
||||
job_slice_count=dict(type='int'),
|
||||
webhook_service=dict(choices=['github', 'gitlab', 'bitbucket_dc', '']),
|
||||
webhook_service=dict(choices=['github', 'gitlab', '']),
|
||||
webhook_credential=dict(),
|
||||
labels=dict(type="list", elements='str'),
|
||||
notification_templates_started=dict(type="list", elements='str'),
|
||||
|
||||
@@ -117,7 +117,6 @@ options:
|
||||
choices:
|
||||
- github
|
||||
- gitlab
|
||||
- bitbucket_dc
|
||||
webhook_credential:
|
||||
description:
|
||||
- Personal Access Token for posting back the status to the service API
|
||||
@@ -829,7 +828,7 @@ def main():
|
||||
ask_inventory_on_launch=dict(type='bool'),
|
||||
ask_scm_branch_on_launch=dict(type='bool'),
|
||||
ask_limit_on_launch=dict(type='bool'),
|
||||
webhook_service=dict(choices=['github', 'gitlab', 'bitbucket_dc']),
|
||||
webhook_service=dict(choices=['github', 'gitlab']),
|
||||
webhook_credential=dict(),
|
||||
labels=dict(type="list", elements='str'),
|
||||
notification_templates_started=dict(type="list", elements='str'),
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
__metaclass__ = type
|
||||
|
||||
import pytest
|
||||
|
||||
from awx.main.models import JobTemplate, WorkflowJobTemplate
|
||||
|
||||
|
||||
# The backend supports these webhook services on job/workflow templates
|
||||
# (see awx/main/models/mixins.py). The collection modules must accept all of
|
||||
# them in their argument_spec ``choices`` list. This test guards against the
|
||||
# module's choices drifting from the backend -- see AAP-45980, where
|
||||
# ``bitbucket_dc`` had been supported by the API since migration 0188 but was
|
||||
# still being rejected by the job_template/workflow_job_template modules.
|
||||
WEBHOOK_SERVICES = ['github', 'gitlab', 'bitbucket_dc']
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.parametrize('webhook_service', WEBHOOK_SERVICES)
|
||||
def test_job_template_accepts_webhook_service(run_module, admin_user, project, inventory, webhook_service):
|
||||
result = run_module(
|
||||
'job_template',
|
||||
{
|
||||
'name': 'foo',
|
||||
'playbook': 'helloworld.yml',
|
||||
'project': project.name,
|
||||
'inventory': inventory.name,
|
||||
'webhook_service': webhook_service,
|
||||
'state': 'present',
|
||||
},
|
||||
admin_user,
|
||||
)
|
||||
|
||||
assert not result.get('failed', False), result.get('msg', result)
|
||||
assert result.get('changed', False), result
|
||||
|
||||
jt = JobTemplate.objects.get(name='foo')
|
||||
assert jt.webhook_service == webhook_service
|
||||
|
||||
# Re-running with the same args must be a no-op (idempotence).
|
||||
result = run_module(
|
||||
'job_template',
|
||||
{
|
||||
'name': 'foo',
|
||||
'playbook': 'helloworld.yml',
|
||||
'project': project.name,
|
||||
'inventory': inventory.name,
|
||||
'webhook_service': webhook_service,
|
||||
'state': 'present',
|
||||
},
|
||||
admin_user,
|
||||
)
|
||||
assert not result.get('failed', False), result.get('msg', result)
|
||||
assert not result.get('changed', True), result
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.parametrize('webhook_service', WEBHOOK_SERVICES)
|
||||
def test_workflow_job_template_accepts_webhook_service(run_module, admin_user, organization, webhook_service):
|
||||
result = run_module(
|
||||
'workflow_job_template',
|
||||
{
|
||||
'name': 'foo-workflow',
|
||||
'organization': organization.name,
|
||||
'webhook_service': webhook_service,
|
||||
'state': 'present',
|
||||
},
|
||||
admin_user,
|
||||
)
|
||||
|
||||
assert not result.get('failed', False), result.get('msg', result)
|
||||
assert result.get('changed', False), result
|
||||
|
||||
wfjt = WorkflowJobTemplate.objects.get(name='foo-workflow')
|
||||
assert wfjt.webhook_service == webhook_service
|
||||
|
||||
# Re-running with the same args must be a no-op (idempotence).
|
||||
result = run_module(
|
||||
'workflow_job_template',
|
||||
{
|
||||
'name': 'foo-workflow',
|
||||
'organization': organization.name,
|
||||
'webhook_service': webhook_service,
|
||||
'state': 'present',
|
||||
},
|
||||
admin_user,
|
||||
)
|
||||
assert not result.get('failed', False), result.get('msg', result)
|
||||
assert not result.get('changed', True), result
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_job_template_rejects_unknown_webhook_service(run_module, admin_user, project, inventory):
|
||||
result = run_module(
|
||||
'job_template',
|
||||
{
|
||||
'name': 'foo',
|
||||
'playbook': 'helloworld.yml',
|
||||
'project': project.name,
|
||||
'inventory': inventory.name,
|
||||
'webhook_service': 'not_a_real_service',
|
||||
'state': 'present',
|
||||
},
|
||||
admin_user,
|
||||
)
|
||||
assert result.get('failed', False), result
|
||||
assert 'webhook_service' in result.get('msg', '')
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_workflow_job_template_rejects_unknown_webhook_service(run_module, admin_user, organization):
|
||||
result = run_module(
|
||||
'workflow_job_template',
|
||||
{
|
||||
'name': 'foo-workflow',
|
||||
'organization': organization.name,
|
||||
'webhook_service': 'not_a_real_service',
|
||||
'state': 'present',
|
||||
},
|
||||
admin_user,
|
||||
)
|
||||
assert result.get('failed', False), result
|
||||
assert 'webhook_service' in result.get('msg', '')
|
||||
@@ -12,6 +12,15 @@ class ConnectionException(exc.Common):
|
||||
pass
|
||||
|
||||
|
||||
class TokenAuth(requests.auth.AuthBase):
|
||||
def __init__(self, token):
|
||||
self.token = token
|
||||
|
||||
def __call__(self, request):
|
||||
request.headers['Authorization'] = 'Bearer {0.token}'.format(self)
|
||||
return request
|
||||
|
||||
|
||||
def log_elapsed(r, *args, **kwargs): # requests hook to display API elapsed time
|
||||
log.debug('"{0.request.method} {0.url}" elapsed: {0.elapsed}'.format(r))
|
||||
|
||||
@@ -37,7 +46,7 @@ class Connection(object):
|
||||
self.get(config.api_base_path) # this causes a cookie w/ the CSRF token to be set
|
||||
return dict(next=next)
|
||||
|
||||
def login(self, username=None, password=None, **kwargs):
|
||||
def login(self, username=None, password=None, token=None, **kwargs):
|
||||
if username and password:
|
||||
_next = kwargs.get('next')
|
||||
if _next:
|
||||
@@ -49,14 +58,11 @@ class Connection(object):
|
||||
self.session_cookie_name = historical_response.headers.get('X-API-Session-Cookie-Name')
|
||||
|
||||
self.session_id = self.session.cookies.get(self.session_cookie_name, None)
|
||||
if self.session_id is None and config.get("api_base_path") == "/api/controller/":
|
||||
# Use gateway session cookie name if controller session cookie name is not found
|
||||
self.session_cookie_name = "gateway_sessionid"
|
||||
self.session_id = self.session.cookies.get(self.session_cookie_name, None)
|
||||
|
||||
self.uses_session_cookie = True
|
||||
else:
|
||||
self.session.auth = (username, password)
|
||||
elif token:
|
||||
self.session.auth = TokenAuth(token)
|
||||
else:
|
||||
self.session.auth = None
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ def separate_async_optionals(creation_order):
|
||||
continue
|
||||
by_count = defaultdict(set)
|
||||
has_creates = [cand for cand in group if hasattr(cand, 'dependencies')]
|
||||
counts = dict.fromkeys(has_creates, 0)
|
||||
counts = {has_create: 0 for has_create in has_creates}
|
||||
for has_create in has_creates:
|
||||
for dependency in has_create.dependencies:
|
||||
for compared in [cand for cand in has_creates if cand != has_create]:
|
||||
@@ -212,7 +212,7 @@ class HasCreate(object):
|
||||
dependency_store = kw.get('ds')
|
||||
if dependency_store is None:
|
||||
deps = self.dependencies + self.optional_dependencies
|
||||
self._dependency_store = dict.fromkeys(deps)
|
||||
self._dependency_store = {base_subclass: None for base_subclass in deps}
|
||||
self.ds = DSAdapter(self.__class__.__name__, self._dependency_store)
|
||||
else:
|
||||
self._dependency_store = dependency_store.dependency_store
|
||||
|
||||
@@ -31,23 +31,7 @@ class User(HasCreate, base.Base):
|
||||
payload = self.create_payload(username=username, password=password, **kwargs)
|
||||
self.password = payload.password
|
||||
|
||||
ctrl_users_api = Users(self.connection)
|
||||
# Check if API base path is set to controller, then use gateway endpoint
|
||||
if config.get("api_base_path") == "/api/controller/":
|
||||
# Use gateway endpoint for user creation
|
||||
gw_users_api = Users(self.connection)
|
||||
gw_users_api.endpoint = "/api/gateway/v1/users/"
|
||||
# Cleanup controller attributes
|
||||
payload["is_platform_auditor"] = payload.get("is_system_auditor")
|
||||
payload.pop("is_system_auditor")
|
||||
# Create gw user
|
||||
gw_user = gw_users_api.post(payload)
|
||||
user = ctrl_users_api.get(username=gw_user.username).results.pop()
|
||||
user.json["password"] = payload.password
|
||||
self.update_identity(user)
|
||||
else:
|
||||
# Use default endpoint
|
||||
self.update_identity(ctrl_users_api.post(payload))
|
||||
self.update_identity(Users(self.connection).post(payload))
|
||||
|
||||
if organization:
|
||||
organization.add_user(self)
|
||||
|
||||
@@ -80,62 +80,26 @@ class CLI(object):
|
||||
def help(self):
|
||||
return '--help' in self.argv or '-h' in self.argv
|
||||
|
||||
def _get_non_option_args(self, before_help=False):
|
||||
"""Extract non-option arguments from argv, optionally only those before help flag."""
|
||||
if before_help and self.help:
|
||||
# Find position of help flag
|
||||
help_pos = next((i for i, arg in enumerate(self.argv) if arg in ('--help', '-h')), len(self.argv))
|
||||
args_to_check = self.argv[:help_pos]
|
||||
else:
|
||||
args_to_check = self.argv
|
||||
|
||||
non_option_args = []
|
||||
i = 0
|
||||
while i < len(args_to_check):
|
||||
arg = args_to_check[i]
|
||||
|
||||
if arg == 'awx':
|
||||
# Skip 'awx' token
|
||||
i += 1
|
||||
elif arg.startswith('-'):
|
||||
# This is an option
|
||||
if '=' in arg:
|
||||
# Long option with value: --opt=val
|
||||
i += 1
|
||||
else:
|
||||
# Option without embedded value: --opt or -o
|
||||
i += 1
|
||||
# Only consume next argument if it exists AND doesn't start with '-'
|
||||
# This naturally handles flag-only options (like --verbose)
|
||||
if i < len(args_to_check) and not args_to_check[i].startswith('-'):
|
||||
i += 1
|
||||
else:
|
||||
# This is a positional argument
|
||||
non_option_args.append(arg)
|
||||
i += 1
|
||||
|
||||
return non_option_args
|
||||
|
||||
def _is_main_help_request(self):
|
||||
"""
|
||||
Determine if help request is for main CLI (awx --help) vs subcommand (awx users create --help).
|
||||
Returns True if this is a main CLI help request that should exit early.
|
||||
"""
|
||||
if not self.help:
|
||||
return False
|
||||
|
||||
# If there are non-option arguments before help flag, this is subcommand help
|
||||
return len(self._get_non_option_args(before_help=True)) == 0
|
||||
|
||||
def authenticate(self):
|
||||
"""Configure the current session for authentication.
|
||||
|
||||
Uses Basic authentication when AWXKIT_FORCE_BASIC_AUTH environment variable
|
||||
is set to true, otherwise defaults to session-based authentication.
|
||||
Authentication priority:
|
||||
1. Token authentication (if --conf.token provided)
|
||||
2. Basic authentication (if AWXKIT_FORCE_BASIC_AUTH=true)
|
||||
3. Session-based authentication (default)
|
||||
|
||||
|
||||
For AAP Gateway environments, set AWXKIT_FORCE_BASIC_AUTH=true to bypass
|
||||
session login restrictions.
|
||||
session login restrictions when using username/password.
|
||||
|
||||
"""
|
||||
# Token authentication (if token is provided)
|
||||
token = self.get_config('token')
|
||||
if token:
|
||||
config.use_sessions = False
|
||||
self.root.connection.login(None, None, token=token)
|
||||
return
|
||||
|
||||
# Check if Basic auth is forced via environment variable
|
||||
if config.get('force_basic_auth', False):
|
||||
config.use_sessions = False
|
||||
@@ -271,16 +235,6 @@ class CLI(object):
|
||||
subparsers = self.subparsers[self.resource].add_subparsers(dest='action', metavar='action')
|
||||
subparsers.required = True
|
||||
|
||||
# Add manual help handling for resource-level help
|
||||
# since we disabled add_help=False for resource subparsers
|
||||
if self.help:
|
||||
# Check if this is resource-level help (no action specified)
|
||||
non_option_args = self._get_non_option_args()
|
||||
if len(non_option_args) == 1 and non_option_args[0] == self.resource:
|
||||
# Only resource specified, no action - show resource-level help
|
||||
self.subparsers[self.resource].print_help()
|
||||
return
|
||||
|
||||
# parse the action from OPTIONS
|
||||
parser = ResourceOptionsParser(self.v2, page, self.resource, subparsers)
|
||||
if parser.deprecated:
|
||||
@@ -289,18 +243,6 @@ class CLI(object):
|
||||
description = colored(description, 'yellow')
|
||||
self.subparsers[self.resource].description = description
|
||||
|
||||
# parse any action arguments FIRST before attempting to parse
|
||||
if self.resource != 'settings':
|
||||
for method in ('list', 'modify', 'create'):
|
||||
if method in parser.parser.choices:
|
||||
if method == 'list':
|
||||
http_method = 'GET'
|
||||
elif method == 'modify' and 'PUT' in parser.options:
|
||||
http_method = 'PUT'
|
||||
else:
|
||||
http_method = 'POST'
|
||||
parser.build_query_arguments(method, http_method)
|
||||
|
||||
if from_sphinx:
|
||||
# Our Sphinx plugin runs `parse_action` for *every* available
|
||||
# resource + action in the API so that it can generate usage
|
||||
@@ -313,6 +255,21 @@ class CLI(object):
|
||||
self.parser.parse_known_args(self.argv)[0]
|
||||
except SystemExit:
|
||||
pass
|
||||
else:
|
||||
self.parser.parse_known_args()[0]
|
||||
|
||||
# parse any action arguments
|
||||
if self.resource != 'settings':
|
||||
for method in ('list', 'modify', 'create'):
|
||||
if method in parser.parser.choices:
|
||||
if method == 'list':
|
||||
http_method = 'GET'
|
||||
elif method == 'modify' and 'PUT' in parser.options:
|
||||
http_method = 'PUT'
|
||||
else:
|
||||
http_method = 'POST'
|
||||
parser.build_query_arguments(method, http_method)
|
||||
if from_sphinx:
|
||||
parsed, extra = self.parser.parse_known_args(self.argv)
|
||||
else:
|
||||
parsed, extra = self.parser.parse_known_args()
|
||||
@@ -367,7 +324,6 @@ class CLI(object):
|
||||
self.argv = argv
|
||||
self.parser = HelpfulArgumentParser(add_help=False)
|
||||
self.parser.add_argument(
|
||||
'-h',
|
||||
'--help',
|
||||
action='store_true',
|
||||
help='prints usage information for the awx tool',
|
||||
@@ -377,13 +333,6 @@ class CLI(object):
|
||||
add_output_formatting_arguments(self.parser, env)
|
||||
|
||||
self.args = self.parser.parse_known_args(self.argv)[0]
|
||||
|
||||
# Early return for help to avoid server connection, but only for main CLI help
|
||||
# Allow subcommand help (like 'awx users create --help') to continue processing
|
||||
if self.help and self._is_main_help_request():
|
||||
self.parser.print_help()
|
||||
sys.exit(0)
|
||||
|
||||
self.verbose = self.get_config('verbose')
|
||||
if self.verbose:
|
||||
logging.basicConfig(level='DEBUG')
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user