mirror of
https://github.com/ansible/awx.git
synced 2026-06-15 11:47:43 -02:30
Compare commits
57 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ffadd3a5a4 | ||
|
|
849f5f796c | ||
|
|
c8981e321e | ||
|
|
d5e5ea3670 | ||
|
|
d566f71ae0 | ||
|
|
c8cb465fde | ||
|
|
49e21d7c1c | ||
|
|
b531151931 | ||
|
|
54857c7a82 | ||
|
|
e03899b581 | ||
|
|
b4f27de4a2 | ||
|
|
5cc467d4cf | ||
|
|
b14b9e1771 | ||
|
|
c4c2779976 | ||
|
|
4bdb11c2a6 | ||
|
|
80f8ee1dec | ||
|
|
f22df56e44 | ||
|
|
fccb6744f9 | ||
|
|
200a68aefa | ||
|
|
9b922f70ed | ||
|
|
e4fa4810eb | ||
|
|
b37f3892b6 | ||
|
|
ec85902b37 | ||
|
|
5eeb854620 | ||
|
|
45480941f8 | ||
|
|
90b7d35554 | ||
|
|
9606366625 | ||
|
|
188c10c7d6 | ||
|
|
2d02a72218 | ||
|
|
d3b40cb57e | ||
|
|
6179b16987 | ||
|
|
cbbd683720 | ||
|
|
2451156fc6 | ||
|
|
83f60cddc2 | ||
|
|
c67d93218f | ||
|
|
eac8968217 | ||
|
|
df771d0e9d | ||
|
|
1213ea6f62 | ||
|
|
b66c0105ae | ||
|
|
d1b3ae53ae | ||
|
|
f3b7d442c3 | ||
|
|
376f964a40 | ||
|
|
c71a49e044 | ||
|
|
99ac0d39dc | ||
|
|
55ad29ac68 | ||
|
|
3fd3b741b6 | ||
|
|
1636abd669 | ||
|
|
d21e0141ce | ||
|
|
e5bae59f5a | ||
|
|
a8afbd1ca3 | ||
|
|
da996c01a0 | ||
|
|
b8c9ae73cd | ||
|
|
d71f18fa44 | ||
|
|
e82a4246f3 | ||
|
|
b83019bde6 | ||
|
|
6d94aa84e7 | ||
|
|
7155400efc |
10
.github/workflows/devel_images.yml
vendored
10
.github/workflows/devel_images.yml
vendored
@@ -13,6 +13,10 @@ on:
|
||||
- stable-*
|
||||
jobs:
|
||||
push-development-images:
|
||||
if: |
|
||||
github.event_name == 'workflow_dispatch' ||
|
||||
(github.repository == 'ansible/awx' && (github.ref_name == 'devel' || startsWith(github.ref_name, 'feature_'))) ||
|
||||
(github.repository == 'ansible/tower' && (startsWith(github.ref_name, 'stable-') || startsWith(github.ref_name, 'release_')))
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 120
|
||||
permissions:
|
||||
@@ -30,12 +34,6 @@ jobs:
|
||||
make-target: awx-kube-buildx
|
||||
steps:
|
||||
|
||||
- name: Skipping build of awx image for non-awx repository
|
||||
run: |
|
||||
echo "Skipping build of awx image for non-awx repository"
|
||||
exit 0
|
||||
if: matrix.build-targets.image-name == 'awx' && !endsWith(github.repository, '/awx')
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
show-progress: false
|
||||
|
||||
2
.github/workflows/pr_body_check.yml
vendored
2
.github/workflows/pr_body_check.yml
vendored
@@ -12,7 +12,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
permissions:
|
||||
packages: write
|
||||
packages: read
|
||||
contents: read
|
||||
steps:
|
||||
- name: Check for each of the lines
|
||||
|
||||
6
.github/workflows/spec-sync-on-merge.yml
vendored
6
.github/workflows/spec-sync-on-merge.yml
vendored
@@ -16,9 +16,15 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- devel
|
||||
- 'stable-2.[6-9]'
|
||||
- 'stable-2.[1-9][0-9]'
|
||||
workflow_dispatch: # Allow manual triggering for testing
|
||||
jobs:
|
||||
sync-openapi-spec:
|
||||
if: |
|
||||
github.event_name == 'workflow_dispatch' ||
|
||||
(github.repository == 'ansible/awx' && (github.ref_name == 'devel' || startsWith(github.ref_name, 'feature_'))) ||
|
||||
(github.repository == 'ansible/tower' && (startsWith(github.ref_name, 'stable-') || startsWith(github.ref_name, 'release_')))
|
||||
name: Sync OpenAPI spec to central repo
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
|
||||
4
.github/workflows/upload_schema.yml
vendored
4
.github/workflows/upload_schema.yml
vendored
@@ -14,6 +14,10 @@ on:
|
||||
- stable-**
|
||||
jobs:
|
||||
push:
|
||||
if: |
|
||||
github.event_name == 'workflow_dispatch' ||
|
||||
(github.repository == 'ansible/awx' && (github.ref_name == 'devel' || startsWith(github.ref_name, 'feature_'))) ||
|
||||
(github.repository == 'ansible/tower' && (startsWith(github.ref_name, 'stable-') || startsWith(github.ref_name, 'release_')))
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 60
|
||||
permissions:
|
||||
|
||||
65
.tekton/run-atf-tests-pull-request.yaml
Normal file
65
.tekton/run-atf-tests-pull-request.yaml
Normal file
@@ -0,0 +1,65 @@
|
||||
---
|
||||
apiVersion: tekton.dev/v1
|
||||
kind: PipelineRun
|
||||
metadata:
|
||||
name: awx-atf-tests-pull-request
|
||||
annotations:
|
||||
build.appstudio.openshift.io/repo: https://github.com/{{repo_owner}}/{{repo_name}}?rev={{revision}}
|
||||
build.appstudio.redhat.com/commit_sha: '{{revision}}'
|
||||
build.appstudio.redhat.com/pull_request_number: '{{pull_request_number}}'
|
||||
build.appstudio.redhat.com/target_branch: '{{target_branch}}'
|
||||
pipelinesascode.tekton.dev/cancel-in-progress: 'true'
|
||||
pipelinesascode.tekton.dev/max-keep-runs: "3"
|
||||
pipelinesascode.tekton.dev/on-comment: "^/run-atf-tests$"
|
||||
pipelinesascode.tekton.dev/target-namespace: ansible-ci-tenant
|
||||
labels:
|
||||
appstudio.openshift.io/application: '{{repo_owner}}'
|
||||
appstudio.openshift.io/component: '{{repo_owner}}-{{repo_name}}'
|
||||
pipelines.appstudio.openshift.io/type: build
|
||||
spec:
|
||||
timeouts:
|
||||
pipeline: "8h"
|
||||
tasks: "7h"
|
||||
finally: "1h"
|
||||
pipelineRef:
|
||||
resolver: bundles
|
||||
params:
|
||||
- name: name
|
||||
value: aap-api-tests
|
||||
- name: bundle
|
||||
value: quay.io/aap-ci/tekton-catalog/pipeline/test/aap-api-tests:0.1@sha256:50aadd6725a239ab53247deb7cf601d1163ceb1792792fd239a3f37d21a490d7
|
||||
- name: kind
|
||||
value: pipeline
|
||||
- name: secret
|
||||
value: quay-aap-ci-viewer
|
||||
|
||||
taskRunTemplate:
|
||||
serviceAccountName: konflux-integration-runner
|
||||
|
||||
params:
|
||||
- name: git-url
|
||||
value: "{{source_url}}"
|
||||
- name: pipeline-github-org
|
||||
value: "{{repo_owner}}"
|
||||
- name: pipeline-github-repo
|
||||
value: "{{repo_name}}"
|
||||
- name: pipeline-github-target-branch
|
||||
value: '{{target_branch}}'
|
||||
- name: pipeline-github-pr-revision
|
||||
value: "{{revision}}"
|
||||
- name: pipeline-github-pr-number
|
||||
value: "{{pull_request_number}}"
|
||||
- name: aap-dev-component-source-name
|
||||
value: "controller"
|
||||
- name: pytest-number-of-parallel-processes
|
||||
value: "6"
|
||||
|
||||
workspaces:
|
||||
- name: workspace
|
||||
volumeClaimTemplate:
|
||||
spec:
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
resources:
|
||||
requests:
|
||||
storage: 1Gi
|
||||
@@ -103,6 +103,12 @@ When necessary, remove any AWX containers and images by running the following:
|
||||
|
||||
### Pre commit hooks
|
||||
|
||||
Install the pre-commit hook before contributing:
|
||||
|
||||
```
|
||||
make pre-commit
|
||||
```
|
||||
|
||||
When you attempt to perform a `git commit` there will be a pre-commit hook that gets run before the commit is allowed to your local repository. For example, python's [black](https://pypi.org/project/black/) will be run to test the formatting of any python files.
|
||||
|
||||
While you can use environment variables to skip the pre-commit hooks GitHub will run similar tests and prevent merging of PRs if the tests do not pass.
|
||||
|
||||
39
Makefile
39
Makefile
@@ -10,6 +10,7 @@ KIND_BIN ?= $(shell which kind)
|
||||
CHROMIUM_BIN=/tmp/chrome-linux/chrome
|
||||
GIT_REPO_NAME ?= $(shell basename `git rev-parse --show-toplevel`)
|
||||
GIT_BRANCH ?= $(shell git rev-parse --abbrev-ref HEAD)
|
||||
GIT_IS_WORKTREE := $(shell test -f .git && echo yes)
|
||||
MANAGEMENT_COMMAND ?= awx-manage
|
||||
VERSION ?= $(shell $(PYTHON) tools/scripts/scm_version.py 2> /dev/null)
|
||||
|
||||
@@ -106,6 +107,15 @@ else
|
||||
DOCKER_KUBE_CACHE_FLAG=$(DOCKER_CACHE)
|
||||
endif
|
||||
|
||||
# AWX TUI variables
|
||||
AWX_HOST ?= https://localhost:8043
|
||||
AWX_USER ?= admin
|
||||
AWX_PASSWORD ?= $$(awk -F"'" '/^admin_password:/{print $$2}' tools/docker-compose/_sources/secrets/admin_password.yml 2>/dev/null || echo "admin")
|
||||
AWX_VERIFY_SSL ?= false
|
||||
|
||||
# For git worktree to find the referenced git dir
|
||||
GIT_COMMON_DIR := $(shell git rev-parse --git-common-dir 2>/dev/null || echo .git)
|
||||
|
||||
.PHONY: awx-link clean clean-tmp clean-venv requirements requirements_dev \
|
||||
update_requirements upgrade_requirements update_requirements_dev \
|
||||
docker_update_requirements docker_upgrade_requirements docker_update_requirements_dev \
|
||||
@@ -113,7 +123,7 @@ endif
|
||||
receiver test test_unit test_coverage coverage_html \
|
||||
sdist \
|
||||
VERSION PYTHON_VERSION docker-compose-sources \
|
||||
.git/hooks/pre-commit
|
||||
pre-commit
|
||||
|
||||
clean-tmp:
|
||||
rm -rf tmp/
|
||||
@@ -342,11 +352,10 @@ black: reports
|
||||
@command -v black >/dev/null 2>&1 || { echo "could not find black on your PATH, you may need to \`pip install black\`, or set AWX_IGNORE_BLACK=1" && exit 1; }
|
||||
@(set -o pipefail && $@ $(BLACK_ARGS) awx awxkit awx_collection | tee reports/$@.report)
|
||||
|
||||
.git/hooks/pre-commit:
|
||||
@echo "if [ -x pre-commit.sh ]; then" > .git/hooks/pre-commit
|
||||
@echo " ./pre-commit.sh;" >> .git/hooks/pre-commit
|
||||
@echo "fi" >> .git/hooks/pre-commit
|
||||
@chmod +x .git/hooks/pre-commit
|
||||
$(GIT_COMMON_DIR)/hooks/pre-commit:
|
||||
ln -sf ../../pre-commit.sh $(GIT_COMMON_DIR)/hooks/pre-commit
|
||||
|
||||
pre-commit: $(GIT_COMMON_DIR)/hooks/pre-commit
|
||||
|
||||
genschema: awx-link reports
|
||||
@if [ "$(VENV_BASE)" ]; then \
|
||||
@@ -521,7 +530,7 @@ ifneq ($(ADMIN_PASSWORD),)
|
||||
EXTRA_SOURCES_ANSIBLE_OPTS := -e admin_password=$(ADMIN_PASSWORD) $(EXTRA_SOURCES_ANSIBLE_OPTS)
|
||||
endif
|
||||
|
||||
docker-compose-sources: .git/hooks/pre-commit
|
||||
docker-compose-sources:
|
||||
@if [ $(MINIKUBE_CONTAINER_GROUP) = true ]; then\
|
||||
$(ANSIBLE_PLAYBOOK) -i tools/docker-compose/inventory -e minikube_setup=$(MINIKUBE_SETUP) tools/docker-compose-minikube/deploy.yml; \
|
||||
fi;
|
||||
@@ -553,7 +562,7 @@ docker-compose: awx/projects docker-compose-sources
|
||||
$(MAKE) docker-compose-up
|
||||
|
||||
docker-compose-up:
|
||||
$(DOCKER_COMPOSE) -f tools/docker-compose/_sources/docker-compose.yml $(COMPOSE_OPTS) up $(COMPOSE_UP_OPTS) --remove-orphans
|
||||
$(if $(GIT_IS_WORKTREE),SETUPTOOLS_SCM_PRETEND_VERSION="$(VERSION)") $(DOCKER_COMPOSE) -f tools/docker-compose/_sources/docker-compose.yml $(COMPOSE_OPTS) up $(COMPOSE_UP_OPTS) --remove-orphans
|
||||
|
||||
docker-compose-down:
|
||||
$(DOCKER_COMPOSE) -f tools/docker-compose/_sources/docker-compose.yml $(COMPOSE_OPTS) down --remove-orphans
|
||||
@@ -571,6 +580,20 @@ docker-compose-runtest: awx/projects docker-compose-sources
|
||||
docker-compose-build-schema: awx/projects docker-compose-sources
|
||||
$(DOCKER_COMPOSE) -f tools/docker-compose/_sources/docker-compose.yml run --rm --service-ports --no-deps awx_1 make genschema
|
||||
|
||||
awx-tui:
|
||||
@if ! command -v awx-tui > /dev/null 2>&1; then \
|
||||
$(PYTHON) -m pip install awx-tui; \
|
||||
fi
|
||||
@if [ -f "$(HOME)/.config/awx-tui/config.yaml" ]; then \
|
||||
$(PYTHON) -m awx_tui.main; \
|
||||
else \
|
||||
AWX_HOST=$(AWX_HOST) \
|
||||
AWX_USER=$(AWX_USER) \
|
||||
AWX_PASSWORD=$(AWX_PASSWORD) \
|
||||
AWX_VERIFY_SSL=$(AWX_VERIFY_SSL) \
|
||||
$(PYTHON) -m awx_tui.main --host $(AWX_HOST); \
|
||||
fi
|
||||
|
||||
SCHEMA_DIFF_BASE_FOLDER ?= awx
|
||||
SCHEMA_DIFF_BASE_BRANCH ?= devel
|
||||
detect-schema-change: genschema
|
||||
|
||||
@@ -52,14 +52,6 @@ except ImportError: # pragma: no cover
|
||||
MODE = 'production'
|
||||
|
||||
|
||||
try:
|
||||
import django # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
from django.db import connection
|
||||
|
||||
|
||||
def prepare_env():
|
||||
# Update the default settings environment variable based on current mode.
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'awx.settings')
|
||||
@@ -79,14 +71,6 @@ def manage():
|
||||
from django.conf import settings
|
||||
from django.core.management import execute_from_command_line
|
||||
|
||||
# enforce the postgres version is a minimum of 12 (we need this for partitioning); if not, then terminate program with exit code of 1
|
||||
# In the future if we require a feature of a version of postgres > 12 this should be updated to reflect that.
|
||||
# The return of connection.pg_version is something like 12013
|
||||
if not os.getenv('SKIP_PG_VERSION_CHECK', False) and not MODE == 'development':
|
||||
if (connection.pg_version // 10000) < 12:
|
||||
sys.stderr.write("At a minimum, postgres version 12 is required\n")
|
||||
sys.exit(1)
|
||||
|
||||
if len(sys.argv) >= 2 and sys.argv[1] in ('version', '--version'): # pragma: no cover
|
||||
sys.stdout.write('%s\n' % __version__)
|
||||
# If running as a user without permission to read settings, display an
|
||||
|
||||
@@ -272,7 +272,10 @@ class APIView(views.APIView):
|
||||
response = self.handle_exception(self.__init_request_error__)
|
||||
if response.status_code == 401:
|
||||
if response.data and 'detail' in response.data:
|
||||
response.data['detail'] += _(' To establish a login session, visit') + ' /api/login/.'
|
||||
if getattr(settings, 'RESOURCE_SERVER__URL', None):
|
||||
response.data['detail'] += _(' Direct access is not allowed, authenticate via the platform gateway.')
|
||||
else:
|
||||
response.data['detail'] += _(' To establish a login session, visit') + ' /api/login/.'
|
||||
logger.info(status_msg)
|
||||
else:
|
||||
logger.warning(status_msg)
|
||||
|
||||
@@ -120,8 +120,7 @@ from awx.main.utils.named_url_graph import reset_counters
|
||||
from awx.main.utils.inventory_vars import update_group_variables
|
||||
from awx.main.scheduler.task_manager_models import TaskManagerModels
|
||||
from awx.main.redact import UriCleaner, REPLACE_STR
|
||||
from awx.main.signals import update_inventory_computed_fields
|
||||
|
||||
from awx.main.tasks.system import update_inventory_computed_fields
|
||||
|
||||
from awx.main.validators import vars_validate_or_raise
|
||||
|
||||
@@ -175,8 +174,8 @@ SUMMARIZABLE_FK_FIELDS = {
|
||||
'workflow_approval': DEFAULT_SUMMARY_FIELDS + ('timeout',),
|
||||
'schedule': DEFAULT_SUMMARY_FIELDS + ('next_run',),
|
||||
'unified_job_template': DEFAULT_SUMMARY_FIELDS + ('unified_job_type',),
|
||||
'last_job': DEFAULT_SUMMARY_FIELDS + ('finished', 'status', 'failed', 'license_error', 'canceled_on'),
|
||||
'last_job_host_summary': DEFAULT_SUMMARY_FIELDS + ('failed',),
|
||||
# last_job and last_job_host_summary are derived from JobHostSummary in HostSerializer,
|
||||
# not from the stale FK fields on Host.
|
||||
'last_update': DEFAULT_SUMMARY_FIELDS + ('status', 'failed', 'license_error'),
|
||||
'current_update': DEFAULT_SUMMARY_FIELDS + ('status', 'failed', 'license_error'),
|
||||
'current_job': DEFAULT_SUMMARY_FIELDS + ('status', 'failed', 'license_error'),
|
||||
@@ -1022,7 +1021,7 @@ class UnifiedJobStdoutSerializer(UnifiedJobSerializer):
|
||||
|
||||
|
||||
class UserSerializer(BaseSerializer):
|
||||
password = serializers.CharField(required=False, default='', help_text=_('Field used to change the password.'))
|
||||
password = serializers.CharField(required=False, default='', allow_blank=True, help_text=_('Field used to change the password.'))
|
||||
is_system_auditor = serializers.BooleanField(default=False)
|
||||
show_capabilities = ['edit', 'delete']
|
||||
|
||||
@@ -1838,19 +1837,35 @@ class HostSerializer(BaseSerializerWithVariables):
|
||||
res['ansible_facts'] = self.reverse('api:host_ansible_facts_detail', kwargs={'pk': obj.instance_id})
|
||||
if obj.inventory:
|
||||
res['inventory'] = self.reverse('api:inventory_detail', kwargs={'pk': obj.inventory.pk})
|
||||
if obj.last_job:
|
||||
res['last_job'] = self.reverse('api:job_detail', kwargs={'pk': obj.last_job.pk})
|
||||
if obj.last_job_host_summary:
|
||||
res['last_job_host_summary'] = self.reverse('api:job_host_summary_detail', kwargs={'pk': obj.last_job_host_summary.pk})
|
||||
last_summary = obj.latest_summary
|
||||
if last_summary:
|
||||
res['last_job_host_summary'] = self.reverse('api:job_host_summary_detail', kwargs={'pk': last_summary.pk})
|
||||
if last_summary.job_id:
|
||||
res['last_job'] = self.reverse('api:job_detail', kwargs={'pk': last_summary.job_id})
|
||||
return res
|
||||
|
||||
def get_summary_fields(self, obj):
|
||||
d = super(HostSerializer, self).get_summary_fields(obj)
|
||||
try:
|
||||
d['last_job']['job_template_id'] = obj.last_job.job_template.id
|
||||
d['last_job']['job_template_name'] = obj.last_job.job_template.name
|
||||
except (KeyError, AttributeError):
|
||||
pass
|
||||
last_summary = obj.latest_summary
|
||||
if last_summary:
|
||||
d['last_job_host_summary'] = OrderedDict()
|
||||
d['last_job_host_summary']['id'] = last_summary.id
|
||||
d['last_job_host_summary']['failed'] = last_summary.failed
|
||||
try:
|
||||
last_job = last_summary.job
|
||||
d['last_job'] = OrderedDict()
|
||||
for field in DEFAULT_SUMMARY_FIELDS + ('finished', 'status', 'failed', 'canceled_on'):
|
||||
fval = getattr(last_job, field, None)
|
||||
if fval is not None:
|
||||
d['last_job'][field] = fval
|
||||
if last_job.job_template:
|
||||
d['last_job']['job_template_id'] = last_job.job_template.id
|
||||
d['last_job']['job_template_name'] = last_job.job_template.name
|
||||
except ObjectDoesNotExist:
|
||||
pass
|
||||
else:
|
||||
d.pop('last_job', None)
|
||||
d.pop('last_job_host_summary', None)
|
||||
if has_model_field_prefetched(obj, 'groups'):
|
||||
group_list = sorted([{'id': g.id, 'name': g.name} for g in obj.groups.all()], key=lambda x: x['id'])[:5]
|
||||
else:
|
||||
@@ -1925,14 +1940,16 @@ class HostSerializer(BaseSerializerWithVariables):
|
||||
return ret
|
||||
if 'inventory' in ret and not obj.inventory:
|
||||
ret['inventory'] = None
|
||||
if 'last_job' in ret and not obj.last_job:
|
||||
ret['last_job'] = None
|
||||
if 'last_job_host_summary' in ret and not obj.last_job_host_summary:
|
||||
ret['last_job_host_summary'] = None
|
||||
last_summary = obj.latest_summary
|
||||
if 'last_job' in ret:
|
||||
ret['last_job'] = last_summary.job_id if last_summary else None
|
||||
if 'last_job_host_summary' in ret:
|
||||
ret['last_job_host_summary'] = last_summary.pk if last_summary else None
|
||||
return ret
|
||||
|
||||
def get_has_active_failures(self, obj):
|
||||
return bool(obj.last_job_host_summary and obj.last_job_host_summary.failed)
|
||||
last_summary = obj.latest_summary
|
||||
return bool(last_summary and last_summary.failed)
|
||||
|
||||
def get_has_inventory_sources(self, obj):
|
||||
return obj.inventory_sources.exists()
|
||||
@@ -2079,9 +2096,17 @@ class BulkHostCreateSerializer(serializers.Serializer):
|
||||
if request and not request.user.is_superuser:
|
||||
if request.user not in inv.admin_role:
|
||||
raise serializers.ValidationError(_(f'Inventory with id {inv.id} not found or lack permissions to add hosts.'))
|
||||
current_hostnames = set(inv.hosts.values_list('name', flat=True))
|
||||
|
||||
# Performance optimization (AAP-67978): Instead of loading ALL host names from
|
||||
# the inventory, only check if the specific new names already exist in the database.
|
||||
new_names = [host['name'] for host in attrs['hosts']]
|
||||
duplicate_new_names = [n for n in new_names if n in current_hostnames or new_names.count(n) > 1]
|
||||
|
||||
new_name_counts = Counter(new_names)
|
||||
duplicates_in_new = [name for name, count in new_name_counts.items() if count > 1]
|
||||
unique_new_names = list(new_name_counts.keys())
|
||||
existing_duplicates = list(Host.objects.filter(inventory=inv, name__in=unique_new_names).values_list('name', flat=True))
|
||||
duplicate_new_names = list(set(duplicates_in_new + existing_duplicates))
|
||||
|
||||
if duplicate_new_names:
|
||||
raise serializers.ValidationError(_(f'Hostnames must be unique in an inventory. Duplicates found: {duplicate_new_names}'))
|
||||
|
||||
@@ -2932,6 +2957,19 @@ class CredentialTypeSerializer(BaseSerializer):
|
||||
field['label'] = _(field['label'])
|
||||
if 'help_text' in field:
|
||||
field['help_text'] = _(field['help_text'])
|
||||
|
||||
# Deep copy inputs to avoid modifying the original model data
|
||||
inputs = value.get('inputs')
|
||||
if not isinstance(inputs, dict):
|
||||
inputs = {}
|
||||
value['inputs'] = copy.deepcopy(inputs)
|
||||
fields = value['inputs'].get('fields', [])
|
||||
if not isinstance(fields, list):
|
||||
fields = []
|
||||
|
||||
# Normalize fields and filter out internal fields
|
||||
value['inputs']['fields'] = [f for f in fields if not f.get('internal')]
|
||||
|
||||
return value
|
||||
|
||||
def filter_field_metadata(self, fields, method):
|
||||
@@ -4122,9 +4160,28 @@ class LaunchConfigurationBaseSerializer(BaseSerializer):
|
||||
attrs['extra_data'][key] = db_extra_data[key]
|
||||
|
||||
# Build unsaved version of this config, use it to detect prompts errors
|
||||
# Capture keys before _build_mock_obj pops pseudo-fields from attrs
|
||||
incoming_attr_keys = set(attrs.keys())
|
||||
mock_obj = self._build_mock_obj(attrs)
|
||||
if set(list(ujt.get_ask_mapping().keys()) + ['extra_data']) & set(attrs.keys()):
|
||||
accepted, rejected, errors = ujt._accept_or_ignore_job_kwargs(_exclude_errors=self.exclude_errors, **mock_obj.prompts_dict())
|
||||
ask_mapping_keys = set(ujt.get_ask_mapping().keys())
|
||||
requested_prompt_fields = incoming_attr_keys & ask_mapping_keys
|
||||
if 'extra_data' in incoming_attr_keys:
|
||||
requested_prompt_fields.add('extra_vars')
|
||||
requested_prompt_fields.add('survey_passwords')
|
||||
|
||||
# prompts_dict() pulls persisted M2M state (labels, credentials,
|
||||
# instance_groups) via the instance pk. Only re-validate the full prompt
|
||||
# state when the caller is switching the underlying template; otherwise
|
||||
# restrict validation to the fields the request explicitly provided.
|
||||
if 'unified_job_template' in attrs:
|
||||
prompts_to_validate = mock_obj.prompts_dict()
|
||||
elif requested_prompt_fields:
|
||||
prompts_to_validate = {k: v for k, v in mock_obj.prompts_dict().items() if k in requested_prompt_fields}
|
||||
else:
|
||||
prompts_to_validate = None
|
||||
|
||||
if prompts_to_validate is not None:
|
||||
accepted, rejected, errors = ujt._accept_or_ignore_job_kwargs(_exclude_errors=self.exclude_errors, **prompts_to_validate)
|
||||
else:
|
||||
# Only perform validation of prompts if prompts fields are provided
|
||||
errors = {}
|
||||
@@ -5393,7 +5450,11 @@ class SchedulePreviewSerializer(BaseSerializer):
|
||||
for a_rule in match_multiple_rrule:
|
||||
if 'interval' not in a_rule.lower():
|
||||
errors.append("{0}: {1}".format(_('INTERVAL required in rrule'), a_rule))
|
||||
elif 'secondly' in a_rule.lower():
|
||||
else:
|
||||
match_interval = re.match(r".*?INTERVAL=([0-9]+)", a_rule)
|
||||
if match_interval and int(match_interval.group(1)) < 1:
|
||||
errors.append("{0}: {1}".format(_("INTERVAL must be a positive integer"), a_rule))
|
||||
if 'secondly' in a_rule.lower():
|
||||
errors.append("{0}: {1}".format(_('SECONDLY is not supported'), a_rule))
|
||||
if re.match(by_day_with_numeric_prefix, a_rule):
|
||||
errors.append("{0}: {1}".format(_("BYDAY with numeric prefix not supported"), a_rule))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
---
|
||||
collections:
|
||||
- name: ansible.receptor
|
||||
version: 2.0.6
|
||||
version: 2.0.8
|
||||
|
||||
@@ -14,13 +14,14 @@ import sys
|
||||
import time
|
||||
from base64 import b64encode
|
||||
from collections import OrderedDict
|
||||
from jwt import decode as _jwt_decode
|
||||
|
||||
from urllib3.exceptions import ConnectTimeoutError
|
||||
|
||||
# Django
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import FieldError, ObjectDoesNotExist
|
||||
from django.db.models import Q, Sum, Count
|
||||
from django.db.models import Q, Sum, Count, Subquery, OuterRef
|
||||
from django.db import IntegrityError, ProgrammingError, transaction, connection
|
||||
from django.db.models.fields.related import ManyToManyField, ForeignKey
|
||||
from django.db.models.functions import Trunc
|
||||
@@ -58,8 +59,13 @@ from drf_spectacular.utils import extend_schema_view, extend_schema
|
||||
from ansible_base.lib.utils.requests import get_remote_hosts
|
||||
from ansible_base.rbac.models import RoleEvaluation
|
||||
from ansible_base.lib.utils.schema import extend_schema_if_available
|
||||
from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope
|
||||
|
||||
# flags
|
||||
from flags.state import flag_enabled
|
||||
|
||||
# AWX
|
||||
from awx.main.utils.workload_identity import retrieve_workload_identity_jwt_with_claims
|
||||
from awx.main.tasks.system import send_notifications, update_inventory_computed_fields
|
||||
from awx.main.access import get_user_queryset
|
||||
from awx.api.generics import (
|
||||
@@ -203,11 +209,12 @@ class DashboardView(APIView):
|
||||
groups_inventory_failed = models.Group.objects.filter(inventory_sources__last_job_failed=True).count()
|
||||
data['groups'] = {'url': reverse('api:group_list', request=request), 'total': user_groups.count(), 'inventory_failed': groups_inventory_failed}
|
||||
|
||||
user_hosts = get_user_queryset(request.user, models.Host)
|
||||
user_hosts_failed = user_hosts.filter(last_job_host_summary__failed=True)
|
||||
user_hosts = get_user_queryset(request.user, models.Host).exclude(inventory__kind='constructed')
|
||||
latest_summary_failed = Subquery(models.JobHostSummary.objects.filter(host_id=OuterRef('pk')).order_by('-id').values('failed')[:1])
|
||||
user_hosts_failed = user_hosts.annotate(_latest_failed=latest_summary_failed).filter(_latest_failed=True)
|
||||
|
||||
data['hosts'] = {
|
||||
'url': reverse('api:host_list', request=request),
|
||||
'failures_url': reverse('api:host_list', request=request) + "?last_job_host_summary__failed=True",
|
||||
'total': user_hosts.count(),
|
||||
'failed': user_hosts_failed.count(),
|
||||
}
|
||||
@@ -794,22 +801,11 @@ class TeamRolesList(SubListAttachDetachAPIView):
|
||||
data = dict(msg=_("You cannot grant system-level permissions to a team."))
|
||||
return Response(data, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
team = get_object_or_404(models.Team, pk=self.kwargs['pk'])
|
||||
credential_content_type = ContentType.objects.get_for_model(models.Credential)
|
||||
if role.content_type == credential_content_type:
|
||||
if not role.content_object.organization:
|
||||
data = dict(
|
||||
msg=_("You cannot grant access to a credential that is not assigned to an organization (private credentials cannot be assigned to teams)")
|
||||
)
|
||||
return Response(data, status=status.HTTP_400_BAD_REQUEST)
|
||||
elif role.content_object.organization.id != team.organization.id:
|
||||
if not request.user.is_superuser:
|
||||
data = dict(
|
||||
msg=_(
|
||||
"You cannot grant a team access to a credential in a different organization. Only superusers can grant cross-organization credential access to teams"
|
||||
)
|
||||
)
|
||||
return Response(data, status=status.HTTP_400_BAD_REQUEST)
|
||||
if not request.data.get('disassociate'):
|
||||
team = get_object_or_404(models.Team, pk=self.kwargs['pk'])
|
||||
content_object = role.content_object
|
||||
if hasattr(content_object, 'validate_role_assignment'):
|
||||
content_object.validate_role_assignment(team, role_definition=None, requesting_user=request.user)
|
||||
|
||||
return super(TeamRolesList, self).post(request, *args, **kwargs)
|
||||
|
||||
@@ -1268,19 +1264,12 @@ class UserRolesList(SubListAttachDetachAPIView):
|
||||
if not sub_id:
|
||||
return super(UserRolesList, self).post(request)
|
||||
|
||||
user = get_object_or_400(models.User, pk=self.kwargs['pk'])
|
||||
role = get_object_or_400(models.Role, pk=sub_id)
|
||||
|
||||
content_types = ContentType.objects.get_for_models(models.Organization, models.Team, models.Credential) # dict of {model: content_type}
|
||||
credential_content_type = content_types[models.Credential]
|
||||
if role.content_type == credential_content_type:
|
||||
if 'disassociate' not in request.data and role.content_object.organization and user not in role.content_object.organization.member_role:
|
||||
data = dict(msg=_("You cannot grant credential access to a user not in the credentials' organization"))
|
||||
return Response(data, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
if not role.content_object.organization and not request.user.is_superuser:
|
||||
data = dict(msg=_("You cannot grant private credential access to another user"))
|
||||
return Response(data, status=status.HTTP_400_BAD_REQUEST)
|
||||
if not request.data.get('disassociate'):
|
||||
role = get_object_or_400(models.Role, pk=sub_id)
|
||||
user = get_object_or_400(models.User, pk=self.kwargs['pk'])
|
||||
content_object = role.content_object
|
||||
if hasattr(content_object, 'validate_role_assignment'):
|
||||
content_object.validate_role_assignment(user, role_definition=None, requesting_user=request.user)
|
||||
|
||||
return super(UserRolesList, self).post(request, *args, **kwargs)
|
||||
|
||||
@@ -1595,7 +1584,175 @@ class CredentialCopy(CopyAPIView):
|
||||
resource_purpose = 'copy of a credential'
|
||||
|
||||
|
||||
class CredentialExternalTest(SubDetailAPIView):
|
||||
class OIDCCredentialTestMixin:
|
||||
"""
|
||||
Mixin to add OIDC workload identity token support to credential test endpoints.
|
||||
|
||||
This mixin provides methods to handle OIDC-enabled external credentials that use
|
||||
workload identity tokens for authentication.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _get_workload_identity_token(job_template: models.JobTemplate, audience: str) -> str:
|
||||
"""Generate a workload identity token for a job template.
|
||||
|
||||
Args:
|
||||
job_template: The JobTemplate instance to generate claims for
|
||||
audience: The JWT audience claim value
|
||||
|
||||
Returns:
|
||||
str: The generated JWT token
|
||||
"""
|
||||
claims = {
|
||||
AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME: job_template.organization.name,
|
||||
AutomationControllerJobScope.CLAIM_ORGANIZATION_ID: job_template.organization.id,
|
||||
AutomationControllerJobScope.CLAIM_PROJECT_NAME: job_template.project.name,
|
||||
AutomationControllerJobScope.CLAIM_PROJECT_ID: job_template.project.id,
|
||||
AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_NAME: job_template.name,
|
||||
AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_ID: job_template.id,
|
||||
AutomationControllerJobScope.CLAIM_PLAYBOOK_NAME: job_template.playbook,
|
||||
}
|
||||
return retrieve_workload_identity_jwt_with_claims(
|
||||
claims=claims,
|
||||
audience=audience,
|
||||
scope=AutomationControllerJobScope.name,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _decode_jwt_payload_for_display(jwt_token):
|
||||
"""Decode JWT payload for display purposes only (signature not verified).
|
||||
|
||||
This is safe because the JWT was just created by AWX and is only decoded
|
||||
to show the user what claims are being sent to the external system.
|
||||
The external system will perform proper signature verification.
|
||||
|
||||
Args:
|
||||
jwt_token: The JWT token to decode
|
||||
|
||||
Returns:
|
||||
dict: The decoded JWT payload
|
||||
"""
|
||||
return _jwt_decode(jwt_token, algorithms=["RS256"], options={"verify_signature": False}) # NOSONAR python:S5659
|
||||
|
||||
def _has_workload_identity_token(self, credential_type_inputs):
|
||||
"""Check if credential type has an internal workload_identity_token field.
|
||||
|
||||
Args:
|
||||
credential_type_inputs: The inputs dict from a credential type
|
||||
|
||||
Returns:
|
||||
bool: True if the credential type has a workload_identity_token field marked as internal
|
||||
"""
|
||||
fields = credential_type_inputs.get('fields', []) if isinstance(credential_type_inputs, dict) else []
|
||||
return any(field.get('internal') and field.get('id') == 'workload_identity_token' for field in fields)
|
||||
|
||||
def _validate_and_get_job_template(self, job_template_id):
|
||||
"""Validate job template ID and return the JobTemplate instance.
|
||||
|
||||
Args:
|
||||
job_template_id: The job template ID from metadata
|
||||
|
||||
Returns:
|
||||
JobTemplate instance
|
||||
|
||||
Raises:
|
||||
ParseError: If job_template_id is invalid or not found
|
||||
"""
|
||||
if job_template_id is None:
|
||||
raise ParseError(_('Job template ID is required.'))
|
||||
|
||||
try:
|
||||
return models.JobTemplate.objects.get(id=int(job_template_id))
|
||||
except ValueError:
|
||||
raise ParseError(_('Job template ID must be an integer.'))
|
||||
except models.JobTemplate.DoesNotExist:
|
||||
raise ParseError(_('Job template with ID %(id)s does not exist.') % {'id': job_template_id})
|
||||
|
||||
def _handle_oidc_credential_test(self, backend_kwargs):
|
||||
"""
|
||||
Handle OIDC workload identity token generation for external credential test endpoints.
|
||||
|
||||
This method should only be called when FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED is enabled
|
||||
and the credential type has a workload_identity_token field.
|
||||
|
||||
Args:
|
||||
backend_kwargs: The kwargs dict to pass to the backend (will be modified in place)
|
||||
|
||||
Returns:
|
||||
dict: Response body containing details with the sent JWT payload
|
||||
|
||||
Raises:
|
||||
PermissionDenied: If user lacks access to the job template (re-raised for 403 response)
|
||||
|
||||
All other exceptions are caught and converted to 400 responses with error details.
|
||||
|
||||
Modifies backend_kwargs in place to add workload_identity_token.
|
||||
"""
|
||||
# Validate job template
|
||||
job_template_id = backend_kwargs.pop('job_template_id', None)
|
||||
job_template = self._validate_and_get_job_template(job_template_id)
|
||||
|
||||
# Check user access
|
||||
if not self.request.user.can_access(models.JobTemplate, 'start', job_template):
|
||||
raise PermissionDenied(_('You do not have access to job template with id: %(id)s.') % {'id': job_template.id})
|
||||
|
||||
# Generate workload identity token
|
||||
jwt_token = self._get_workload_identity_token(job_template, backend_kwargs.get('url'))
|
||||
backend_kwargs['workload_identity_token'] = jwt_token
|
||||
|
||||
return {'details': {'sent_jwt_payload': self._decode_jwt_payload_for_display(jwt_token)}}
|
||||
|
||||
def _call_backend_with_error_handling(self, plugin, backend_kwargs, response_body):
|
||||
"""Call credential backend and handle errors."""
|
||||
try:
|
||||
with set_environ(**settings.AWX_TASK_ENV):
|
||||
plugin.backend(**backend_kwargs)
|
||||
return Response(response_body, status=status.HTTP_202_ACCEPTED)
|
||||
except requests.exceptions.HTTPError as exc:
|
||||
message = self._extract_http_error_message(exc)
|
||||
self._add_error_to_response(response_body, message)
|
||||
return Response(response_body, status=status.HTTP_400_BAD_REQUEST)
|
||||
except Exception as exc:
|
||||
message = self._extract_generic_error_message(exc)
|
||||
self._add_error_to_response(response_body, message)
|
||||
return Response(response_body, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
@staticmethod
|
||||
def _extract_http_error_message(exc):
|
||||
"""Extract error message from HTTPError, checking response JSON and text."""
|
||||
message = str(exc)
|
||||
if not hasattr(exc, 'response') or exc.response is None:
|
||||
return message
|
||||
|
||||
try:
|
||||
error_data = exc.response.json()
|
||||
if 'errors' in error_data and error_data['errors']:
|
||||
return ', '.join(error_data['errors'])
|
||||
if 'error' in error_data:
|
||||
return error_data['error']
|
||||
except (ValueError, KeyError):
|
||||
if exc.response.text:
|
||||
return exc.response.text
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def _extract_generic_error_message(exc):
|
||||
"""Extract error message from exception, handling ConnectTimeoutError specially."""
|
||||
message = str(exc) if str(exc) else exc.__class__.__name__
|
||||
for arg in getattr(exc, 'args', []):
|
||||
if isinstance(getattr(arg, 'reason', None), ConnectTimeoutError):
|
||||
return str(arg.reason)
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def _add_error_to_response(response_body, message):
|
||||
"""Add error message to both 'detail' and 'details.error_message' fields."""
|
||||
response_body['detail'] = message
|
||||
if 'details' in response_body:
|
||||
response_body['details']['error_message'] = message
|
||||
|
||||
|
||||
class CredentialExternalTest(OIDCCredentialTestMixin, SubDetailAPIView):
|
||||
"""
|
||||
Test updates to the input values and metadata of an external credential
|
||||
before saving them.
|
||||
@@ -1615,6 +1772,8 @@ class CredentialExternalTest(SubDetailAPIView):
|
||||
It does not support standard credential types such as Machine, SCM, and Cloud."""})
|
||||
def post(self, request, *args, **kwargs):
|
||||
obj = self.get_object()
|
||||
if obj.credential_type.kind != 'external':
|
||||
raise ParseError(_('Credential is not testable.'))
|
||||
backend_kwargs = {}
|
||||
for field_name, value in obj.inputs.items():
|
||||
backend_kwargs[field_name] = obj.get_input(field_name)
|
||||
@@ -1622,23 +1781,22 @@ class CredentialExternalTest(SubDetailAPIView):
|
||||
if value != '$encrypted$':
|
||||
backend_kwargs[field_name] = value
|
||||
backend_kwargs.update(request.data.get('metadata', {}))
|
||||
try:
|
||||
with set_environ(**settings.AWX_TASK_ENV):
|
||||
obj.credential_type.plugin.backend(**backend_kwargs)
|
||||
return Response({}, status=status.HTTP_202_ACCEPTED)
|
||||
except requests.exceptions.HTTPError:
|
||||
message = """Test operation is not supported for credential type {}.
|
||||
This endpoint only supports credentials that connect to
|
||||
external secret management systems such as CyberArk, HashiCorp
|
||||
Vault, or cloud-based secret managers.""".format(obj.credential_type.kind)
|
||||
return Response({'detail': message}, status=status.HTTP_400_BAD_REQUEST)
|
||||
except Exception as exc:
|
||||
message = exc.__class__.__name__
|
||||
exc_args = getattr(exc, 'args', [])
|
||||
for a in exc_args:
|
||||
if isinstance(getattr(a, 'reason', None), ConnectTimeoutError):
|
||||
message = str(a.reason)
|
||||
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
# Handle OIDC workload identity token generation if enabled
|
||||
response_body = {}
|
||||
if flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED') and self._has_workload_identity_token(obj.credential_type.inputs):
|
||||
try:
|
||||
oidc_response_body = self._handle_oidc_credential_test(backend_kwargs)
|
||||
response_body.update(oidc_response_body)
|
||||
except PermissionDenied:
|
||||
raise
|
||||
except Exception as exc:
|
||||
error_message = str(exc.detail) if hasattr(exc, 'detail') else str(exc)
|
||||
response_body['detail'] = error_message
|
||||
response_body['details'] = {'error_message': error_message}
|
||||
return Response(response_body, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
return self._call_backend_with_error_handling(obj.credential_type.plugin, backend_kwargs, response_body)
|
||||
|
||||
|
||||
class CredentialInputSourceDetail(RetrieveUpdateDestroyAPIView):
|
||||
@@ -1668,7 +1826,7 @@ class CredentialInputSourceSubList(SubListCreateAPIView):
|
||||
parent_key = 'target_credential'
|
||||
|
||||
|
||||
class CredentialTypeExternalTest(SubDetailAPIView):
|
||||
class CredentialTypeExternalTest(OIDCCredentialTestMixin, SubDetailAPIView):
|
||||
"""
|
||||
Test a complete set of input values for an external credential before
|
||||
saving it.
|
||||
@@ -1683,21 +1841,26 @@ class CredentialTypeExternalTest(SubDetailAPIView):
|
||||
@extend_schema_if_available(extensions={"x-ai-description": "Test a complete set of input values for an external credential"})
|
||||
def post(self, request, *args, **kwargs):
|
||||
obj = self.get_object()
|
||||
if obj.kind != 'external':
|
||||
raise ParseError(_('Credential type is not testable.'))
|
||||
backend_kwargs = request.data.get('inputs', {})
|
||||
backend_kwargs.update(request.data.get('metadata', {}))
|
||||
try:
|
||||
obj.plugin.backend(**backend_kwargs)
|
||||
return Response({}, status=status.HTTP_202_ACCEPTED)
|
||||
except requests.exceptions.HTTPError as exc:
|
||||
message = 'HTTP {}'.format(exc.response.status_code)
|
||||
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
|
||||
except Exception as exc:
|
||||
message = exc.__class__.__name__
|
||||
args_exc = getattr(exc, 'args', [])
|
||||
for a in args_exc:
|
||||
if isinstance(getattr(a, 'reason', None), ConnectTimeoutError):
|
||||
message = str(a.reason)
|
||||
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
# Handle OIDC workload identity token generation if enabled
|
||||
response_body = {}
|
||||
if flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED') and self._has_workload_identity_token(obj.inputs):
|
||||
try:
|
||||
oidc_response_body = self._handle_oidc_credential_test(backend_kwargs)
|
||||
response_body.update(oidc_response_body)
|
||||
except PermissionDenied:
|
||||
raise
|
||||
except Exception as exc:
|
||||
error_message = str(exc.detail) if hasattr(exc, 'detail') else str(exc)
|
||||
response_body['detail'] = error_message
|
||||
response_body['details'] = {'error_message': error_message}
|
||||
return Response(response_body, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
return self._call_backend_with_error_handling(obj.plugin, backend_kwargs, response_body)
|
||||
|
||||
|
||||
class HostRelatedSearchMixin(object):
|
||||
@@ -1763,7 +1926,7 @@ class HostList(HostRelatedSearchMixin, ListCreateAPIView):
|
||||
if filter_string:
|
||||
filter_qs = SmartFilter.query_from_string(filter_string)
|
||||
qs &= filter_qs
|
||||
return qs.distinct()
|
||||
return qs.distinct().with_latest_summary_id()
|
||||
|
||||
def list(self, *args, **kwargs):
|
||||
try:
|
||||
@@ -1778,6 +1941,9 @@ class HostDetail(RelatedJobsPreventDeleteMixin, RetrieveUpdateDestroyAPIView):
|
||||
serializer_class = serializers.HostSerializer
|
||||
resource_purpose = 'host detail'
|
||||
|
||||
def get_queryset(self):
|
||||
return super().get_queryset().with_latest_summary_id()
|
||||
|
||||
@extend_schema_if_available(extensions={"x-ai-description": "Delete a host"})
|
||||
def delete(self, request, *args, **kwargs):
|
||||
if self.get_object().inventory.pending_deletion:
|
||||
@@ -1811,6 +1977,9 @@ class InventoryHostsList(HostRelatedSearchMixin, SubListCreateAttachDetachAPIVie
|
||||
filter_read_permission = False
|
||||
resource_purpose = 'hosts of an inventory'
|
||||
|
||||
def get_queryset(self):
|
||||
return super().get_queryset().with_latest_summary_id()
|
||||
|
||||
|
||||
class HostGroupsList(SubListCreateAttachDetachAPIView):
|
||||
'''the list of groups a host is directly a member of'''
|
||||
@@ -1994,6 +2163,9 @@ class GroupHostsList(HostRelatedSearchMixin, SubListCreateAttachDetachAPIView):
|
||||
relationship = 'hosts'
|
||||
resource_purpose = 'hosts of a group'
|
||||
|
||||
def get_queryset(self):
|
||||
return super().get_queryset().with_latest_summary_id()
|
||||
|
||||
def update_raw_data(self, data):
|
||||
data.pop('inventory', None)
|
||||
return super(GroupHostsList, self).update_raw_data(data)
|
||||
@@ -2025,7 +2197,7 @@ class GroupAllHostsList(HostRelatedSearchMixin, SubListAPIView):
|
||||
self.check_parent_access(parent)
|
||||
qs = self.request.user.get_queryset(self.model).distinct() # need distinct for '&' operator
|
||||
sublist_qs = parent.all_hosts.distinct()
|
||||
return qs & sublist_qs
|
||||
return (qs & sublist_qs).with_latest_summary_id()
|
||||
|
||||
|
||||
class GroupInventorySourcesList(SubListAPIView):
|
||||
@@ -2318,6 +2490,9 @@ class InventorySourceHostsList(HostRelatedSearchMixin, SubListDestroyAPIView):
|
||||
check_sub_obj_permission = False
|
||||
resource_purpose = 'hosts of an inventory source'
|
||||
|
||||
def get_queryset(self):
|
||||
return super().get_queryset().with_latest_summary_id()
|
||||
|
||||
def perform_list_destroy(self, instance_list):
|
||||
inv_source = self.get_parent_object()
|
||||
with ignore_inventory_computed_fields():
|
||||
@@ -4695,19 +4870,12 @@ class RoleUsersList(SubListAttachDetachAPIView):
|
||||
if not sub_id:
|
||||
return super(RoleUsersList, self).post(request)
|
||||
|
||||
user = get_object_or_400(models.User, pk=sub_id)
|
||||
role = self.get_parent_object()
|
||||
|
||||
content_types = ContentType.objects.get_for_models(models.Organization, models.Team, models.Credential) # dict of {model: content_type}
|
||||
credential_content_type = content_types[models.Credential]
|
||||
if role.content_type == credential_content_type:
|
||||
if 'disassociate' not in request.data and role.content_object.organization and user not in role.content_object.organization.member_role:
|
||||
data = dict(msg=_("You cannot grant credential access to a user not in the credentials' organization"))
|
||||
return Response(data, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
if not role.content_object.organization and not request.user.is_superuser:
|
||||
data = dict(msg=_("You cannot grant private credential access to another user"))
|
||||
return Response(data, status=status.HTTP_400_BAD_REQUEST)
|
||||
if not request.data.get('disassociate'):
|
||||
user = get_object_or_400(models.User, pk=sub_id)
|
||||
role = self.get_parent_object()
|
||||
content_object = role.content_object
|
||||
if hasattr(content_object, 'validate_role_assignment'):
|
||||
content_object.validate_role_assignment(user, role_definition=None, requesting_user=request.user)
|
||||
|
||||
return super(RoleUsersList, self).post(request, *args, **kwargs)
|
||||
|
||||
@@ -4740,24 +4908,6 @@ class RoleTeamsList(SubListAttachDetachAPIView):
|
||||
data = dict(msg=_("You cannot assign an Organization participation role as a child role for a Team."))
|
||||
return Response(data, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
credential_content_type = ContentType.objects.get_for_model(models.Credential)
|
||||
if role.content_type == credential_content_type:
|
||||
# Private credentials (no organization) are never allowed for teams
|
||||
if not role.content_object.organization:
|
||||
data = dict(
|
||||
msg=_("You cannot grant access to a credential that is not assigned to an organization (private credentials cannot be assigned to teams)")
|
||||
)
|
||||
return Response(data, status=status.HTTP_400_BAD_REQUEST)
|
||||
# Cross-organization credentials are only allowed for superusers
|
||||
elif role.content_object.organization.id != team.organization.id:
|
||||
if not request.user.is_superuser:
|
||||
data = dict(
|
||||
msg=_(
|
||||
"You cannot grant a team access to a credential in a different organization. Only superusers can grant cross-organization credential access to teams"
|
||||
)
|
||||
)
|
||||
return Response(data, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
action = 'attach'
|
||||
if request.data.get('disassociate', None):
|
||||
action = 'unattach'
|
||||
@@ -4766,6 +4916,11 @@ class RoleTeamsList(SubListAttachDetachAPIView):
|
||||
data = dict(msg=_("You cannot grant system-level permissions to a team."))
|
||||
return Response(data, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
if action == 'attach':
|
||||
content_object = role.content_object
|
||||
if hasattr(content_object, 'validate_role_assignment'):
|
||||
content_object.validate_role_assignment(team, role_definition=None, requesting_user=request.user)
|
||||
|
||||
if not request.user.can_access(self.parent_model, action, role, team, self.relationship, request.data, skip_sub_obj_read_check=False):
|
||||
raise PermissionDenied()
|
||||
if request.data.get('disassociate', None):
|
||||
|
||||
@@ -49,7 +49,6 @@ class GetNotAllowedMixin(object):
|
||||
class AnalyticsRootView(APIView):
|
||||
permission_classes = (AnalyticsPermission,)
|
||||
name = _('Automation Analytics')
|
||||
swagger_topic = 'Automation Analytics'
|
||||
resource_purpose = 'automation analytics endpoints'
|
||||
|
||||
@extend_schema_if_available(extensions={"x-ai-description": "A list of additional API endpoints related to analytics"})
|
||||
@@ -306,7 +305,6 @@ class AnalyticsAuthorizedView(AnalyticsGenericListView):
|
||||
|
||||
class AnalyticsReportsList(GetNotAllowedMixin, AnalyticsGenericListView):
|
||||
name = _("Reports")
|
||||
swagger_topic = "Automation Analytics"
|
||||
resource_purpose = 'automation analytics reports'
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,8 @@
|
||||
import dateutil
|
||||
import logging
|
||||
|
||||
from django.db.models import Count
|
||||
from django.db.models import Count, IntegerField, OuterRef, Subquery
|
||||
from django.db.models.functions import Coalesce
|
||||
from django.db import transaction
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django.utils.timezone import now
|
||||
@@ -15,7 +16,7 @@ from rest_framework.response import Response
|
||||
from rest_framework import status
|
||||
|
||||
from awx.main.constants import ACTIVE_STATES
|
||||
from awx.main.models import Organization
|
||||
from awx.main.models import Organization, Role
|
||||
from awx.main.utils import get_object_or_400
|
||||
from awx.main.models.ha import Instance, InstanceGroup, schedule_policy_task
|
||||
from awx.main.models.organization import Team
|
||||
@@ -178,9 +179,22 @@ class OrganizationCountsMixin(object):
|
||||
db_results['projects'] = project_qs.values('organization').annotate(Count('organization')).order_by('organization')
|
||||
|
||||
# Other members and admins of organization are always viewable
|
||||
db_results['users'] = org_qs.annotate(users=Count('member_role__members', distinct=True), admins=Count('admin_role__members', distinct=True)).values(
|
||||
'id', 'users', 'admins'
|
||||
#
|
||||
# Use independent subqueries instead of double-JOIN Count to avoid
|
||||
# cartesian product.
|
||||
RoleMember = Role.members.through
|
||||
member_count = Subquery(
|
||||
RoleMember.objects.filter(role_id=OuterRef('member_role_id')).values('role_id').annotate(cnt=Count('user_id', distinct=True)).values('cnt'),
|
||||
output_field=IntegerField(),
|
||||
)
|
||||
admin_count = Subquery(
|
||||
RoleMember.objects.filter(role_id=OuterRef('admin_role_id')).values('role_id').annotate(cnt=Count('user_id', distinct=True)).values('cnt'),
|
||||
output_field=IntegerField(),
|
||||
)
|
||||
db_results['users'] = org_qs.annotate(
|
||||
users=Coalesce(member_count, 0),
|
||||
admins=Coalesce(admin_count, 0),
|
||||
).values('id', 'users', 'admins')
|
||||
|
||||
count_context = {}
|
||||
for org in org_id_list:
|
||||
|
||||
@@ -5,7 +5,8 @@
|
||||
import logging
|
||||
|
||||
# Django
|
||||
from django.db.models import Count
|
||||
from django.db.models import Count, IntegerField, OuterRef, Subquery
|
||||
from django.db.models.functions import Coalesce
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
@@ -77,10 +78,19 @@ class OrganizationDetail(RelatedJobsPreventDeleteMixin, RetrieveUpdateDestroyAPI
|
||||
|
||||
org_counts = {}
|
||||
access_kwargs = {'accessor': self.request.user, 'role_field': 'read_role'}
|
||||
# Use independent subqueries instead of double-JOIN Count to avoid
|
||||
# cartesian product.
|
||||
RoleMember = Role.members.through
|
||||
member_count = Subquery(
|
||||
RoleMember.objects.filter(role_id=OuterRef('member_role_id')).values('role_id').annotate(cnt=Count('user_id', distinct=True)).values('cnt'),
|
||||
output_field=IntegerField(),
|
||||
)
|
||||
admin_count = Subquery(
|
||||
RoleMember.objects.filter(role_id=OuterRef('admin_role_id')).values('role_id').annotate(cnt=Count('user_id', distinct=True)).values('cnt'),
|
||||
output_field=IntegerField(),
|
||||
)
|
||||
direct_counts = (
|
||||
Organization.objects.filter(id=org_id)
|
||||
.annotate(users=Count('member_role__members', distinct=True), admins=Count('admin_role__members', distinct=True))
|
||||
.values('users', 'admins')
|
||||
Organization.objects.filter(id=org_id).annotate(users=Coalesce(member_count, 0), admins=Coalesce(admin_count, 0)).values('users', 'admins')
|
||||
)
|
||||
|
||||
if not direct_counts:
|
||||
|
||||
@@ -344,13 +344,22 @@ class ApiV2ConfigView(APIView):
|
||||
become_methods=PRIVILEGE_ESCALATION_METHODS,
|
||||
)
|
||||
|
||||
if (
|
||||
request.user.is_superuser
|
||||
or request.user.is_system_auditor
|
||||
or Organization.accessible_objects(request.user, 'admin_role').exists()
|
||||
or Organization.accessible_objects(request.user, 'auditor_role').exists()
|
||||
or Organization.accessible_objects(request.user, 'project_admin_role').exists()
|
||||
):
|
||||
# Check superuser/auditor first
|
||||
if request.user.is_superuser or request.user.is_system_auditor:
|
||||
has_org_access = True
|
||||
else:
|
||||
# Single query checking all three organization role types at once
|
||||
has_org_access = (
|
||||
(
|
||||
Organization.access_qs(request.user, 'change')
|
||||
| Organization.access_qs(request.user, 'audit')
|
||||
| Organization.access_qs(request.user, 'add_project')
|
||||
)
|
||||
.distinct()
|
||||
.exists()
|
||||
)
|
||||
|
||||
if has_org_access:
|
||||
data.update(
|
||||
dict(
|
||||
project_base_dir=settings.PROJECTS_ROOT,
|
||||
@@ -358,8 +367,10 @@ class ApiV2ConfigView(APIView):
|
||||
custom_virtualenvs=get_custom_venv_choices(),
|
||||
)
|
||||
)
|
||||
elif JobTemplate.accessible_objects(request.user, 'admin_role').exists():
|
||||
data['custom_virtualenvs'] = get_custom_venv_choices()
|
||||
else:
|
||||
# Only check JobTemplate access if org check failed
|
||||
if JobTemplate.accessible_objects(request.user, 'admin_role').exists():
|
||||
data['custom_virtualenvs'] = get_custom_venv_choices()
|
||||
|
||||
return Response(data)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from awx.api import serializers
|
||||
from awx.api.generics import APIView, GenericAPIView
|
||||
from awx.api.permissions import WebhookKeyPermission
|
||||
from awx.main.models import Job, JobTemplate, WorkflowJob, WorkflowJobTemplate
|
||||
from awx.main.constants import JOB_VARIABLE_PREFIXES
|
||||
from awx.main.utils.common import get_job_variable_prefixes
|
||||
|
||||
logger = logging.getLogger('awx.api.views.webhooks')
|
||||
|
||||
@@ -166,7 +166,7 @@ class WebhookReceiverBase(APIView):
|
||||
'extra_vars': {},
|
||||
}
|
||||
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in get_job_variable_prefixes():
|
||||
kwargs['extra_vars']['{}_webhook_event_type'.format(name)] = event_type
|
||||
kwargs['extra_vars']['{}_webhook_event_guid'.format(name)] = event_guid
|
||||
kwargs['extra_vars']['{}_webhook_event_ref'.format(name)] = event_ref
|
||||
|
||||
@@ -897,8 +897,6 @@ class HostAccess(BaseAccess):
|
||||
'created_by',
|
||||
'modified_by',
|
||||
'inventory',
|
||||
'last_job__job_template',
|
||||
'last_job_host_summary__job',
|
||||
)
|
||||
prefetch_related = ('groups', 'inventory_sources')
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import pathlib
|
||||
import shutil
|
||||
import tarfile
|
||||
import tempfile
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.serializers.json import DjangoJSONEncoder
|
||||
@@ -23,6 +24,8 @@ from awx.main.models import Job
|
||||
from awx.main.access import access_registry
|
||||
from awx.main.utils import get_awx_http_client_headers, set_environ, datetime_hook
|
||||
from awx.main.utils.analytics_proxy import OIDCClient
|
||||
from awx.main.utils.candlepin import get_or_generate_candlepin_certificate
|
||||
from awx.main.utils.candlepin.client import _temp_cert_files
|
||||
|
||||
__all__ = ['register', 'gather', 'ship']
|
||||
|
||||
@@ -41,6 +44,76 @@ def _valid_license():
|
||||
return True
|
||||
|
||||
|
||||
def _get_cert_upload_url(url):
|
||||
"""
|
||||
Convert analytics URL to use 'cert.' subdomain for mTLS uploads.
|
||||
|
||||
Some analytics services use different hostnames for different auth methods:
|
||||
- cert.example.com - for mTLS (certificate-based) uploads
|
||||
- example.com - for OIDC (token-based) uploads
|
||||
|
||||
Args:
|
||||
url: Original analytics URL
|
||||
|
||||
Returns:
|
||||
URL with 'cert.' prepended to hostname if not already present
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
hostname = parsed.hostname
|
||||
|
||||
# Only modify if hostname doesn't already start with 'cert.'
|
||||
if hostname and not hostname.startswith('cert.'):
|
||||
new_hostname = f'cert.{hostname}'
|
||||
# Reconstruct URL with new hostname
|
||||
netloc = new_hostname
|
||||
if parsed.port:
|
||||
netloc = f'{new_hostname}:{parsed.port}'
|
||||
|
||||
new_parsed = parsed._replace(netloc=netloc)
|
||||
return urlunparse(new_parsed)
|
||||
|
||||
return url
|
||||
except Exception as e:
|
||||
logger.warning(f'Could not modify URL for cert upload: {e}, using original URL')
|
||||
return url
|
||||
|
||||
|
||||
def _get_analytics_credentials():
|
||||
"""
|
||||
Get Red Hat Insights credentials from settings.
|
||||
|
||||
Attempts to retrieve credentials in the following priority order:
|
||||
1. REDHAT_USERNAME / REDHAT_PASSWORD
|
||||
2. SUBSCRIPTIONS_USERNAME / SUBSCRIPTIONS_PASSWORD
|
||||
3. SUBSCRIPTIONS_CLIENT_ID / SUBSCRIPTIONS_CLIENT_SECRET
|
||||
|
||||
Returns:
|
||||
tuple: (username, password) if credentials are found, (None, None) otherwise
|
||||
"""
|
||||
rh_id = getattr(settings, 'REDHAT_USERNAME', None)
|
||||
rh_secret = getattr(settings, 'REDHAT_PASSWORD', None)
|
||||
|
||||
if rh_id and rh_secret:
|
||||
return rh_id, rh_secret
|
||||
|
||||
# Try SUBSCRIPTIONS_USERNAME / SUBSCRIPTIONS_PASSWORD
|
||||
rh_id = getattr(settings, 'SUBSCRIPTIONS_USERNAME', None)
|
||||
rh_secret = getattr(settings, 'SUBSCRIPTIONS_PASSWORD', None)
|
||||
|
||||
if rh_id and rh_secret:
|
||||
return rh_id, rh_secret
|
||||
|
||||
# Try SUBSCRIPTIONS_CLIENT_ID / SUBSCRIPTIONS_CLIENT_SECRET
|
||||
rh_id = getattr(settings, 'SUBSCRIPTIONS_CLIENT_ID', None)
|
||||
rh_secret = getattr(settings, 'SUBSCRIPTIONS_CLIENT_SECRET', None)
|
||||
|
||||
if rh_id and rh_secret:
|
||||
return rh_id, rh_secret
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def all_collectors():
|
||||
from awx.main.analytics import collectors
|
||||
|
||||
@@ -184,10 +257,8 @@ def gather(dest=None, module=None, subset=None, since=None, until=None, collecti
|
||||
logger.log(log_level, "Automation Analytics not enabled. Use --dry-run to gather locally without sending.")
|
||||
return None
|
||||
|
||||
if not (
|
||||
settings.AUTOMATION_ANALYTICS_URL
|
||||
and ((settings.REDHAT_USERNAME and settings.REDHAT_PASSWORD) or (settings.SUBSCRIPTIONS_CLIENT_ID and settings.SUBSCRIPTIONS_CLIENT_SECRET))
|
||||
):
|
||||
rh_id, rh_secret = _get_analytics_credentials()
|
||||
if not (settings.AUTOMATION_ANALYTICS_URL and rh_id and rh_secret):
|
||||
logger.log(log_level, "Not gathering analytics, configuration is invalid. Use --dry-run to gather locally without sending.")
|
||||
return None
|
||||
|
||||
@@ -368,19 +439,14 @@ def ship(path):
|
||||
logger.error('AUTOMATION_ANALYTICS_URL is not set')
|
||||
return False
|
||||
|
||||
rh_id = getattr(settings, 'REDHAT_USERNAME', None)
|
||||
rh_secret = getattr(settings, 'REDHAT_PASSWORD', None)
|
||||
|
||||
if not (rh_id and rh_secret):
|
||||
rh_id = getattr(settings, 'SUBSCRIPTIONS_CLIENT_ID', None)
|
||||
rh_secret = getattr(settings, 'SUBSCRIPTIONS_CLIENT_SECRET', None)
|
||||
rh_id, rh_secret = _get_analytics_credentials()
|
||||
|
||||
if not rh_id:
|
||||
logger.error('Neither REDHAT_USERNAME nor SUBSCRIPTIONS_CLIENT_ID are set')
|
||||
logger.error('No valid username found. Tried: REDHAT_USERNAME, SUBSCRIPTIONS_USERNAME, SUBSCRIPTIONS_CLIENT_ID')
|
||||
return False
|
||||
|
||||
if not rh_secret:
|
||||
logger.error('Neither REDHAT_PASSWORD nor SUBSCRIPTIONS_CLIENT_SECRET are set')
|
||||
logger.error('No valid password found. Tried: REDHAT_PASSWORD, SUBSCRIPTIONS_PASSWORD, SUBSCRIPTIONS_CLIENT_SECRET')
|
||||
return False
|
||||
|
||||
with open(path, 'rb') as f:
|
||||
@@ -388,17 +454,40 @@ def ship(path):
|
||||
s = requests.Session()
|
||||
s.headers = get_awx_http_client_headers()
|
||||
s.headers.pop('Content-Type')
|
||||
|
||||
with set_environ(**settings.AWX_TASK_ENV):
|
||||
# Try Certificate-based mTLS authentication (zero-touch)
|
||||
cert_pem, key_pem = get_or_generate_candlepin_certificate()
|
||||
if cert_pem and key_pem:
|
||||
# Use cert. subdomain for mTLS uploads
|
||||
cert_url = _get_cert_upload_url(url)
|
||||
logger.debug("Attempting certificate-based authentication for analytics upload")
|
||||
try:
|
||||
with _temp_cert_files(cert_pem, key_pem) as (cert_path, key_path):
|
||||
response = s.post(
|
||||
cert_url, files=files, cert=(cert_path, key_path), verify=settings.INSIGHTS_CERT_PATH, headers=s.headers, timeout=(31, 31)
|
||||
)
|
||||
if response.status_code < 300:
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
f'Certificate-based authentication failed with status {response.status_code}, {response.text}. Falling back to OIDC auth'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Certificate-based authentication failed: {e}, falling back to OIDC auth")
|
||||
|
||||
# Try OIDC authentication
|
||||
logger.debug("Attempting OIDC authentication for analytics upload")
|
||||
f.seek(0) # requests POST may read from the handler, so seek to beginning of file for the next POST attempt
|
||||
try:
|
||||
client = OIDCClient(rh_id, rh_secret)
|
||||
response = client.make_request("POST", url, headers=s.headers, files=files, verify=settings.INSIGHTS_CERT_PATH, timeout=(31, 31))
|
||||
except requests.RequestException:
|
||||
logger.error("Automation Analytics API request failed, trying base auth method")
|
||||
response = s.post(url, files=files, verify=settings.INSIGHTS_CERT_PATH, auth=(rh_id, rh_secret), headers=s.headers, timeout=(31, 31))
|
||||
|
||||
# Accept 2XX status_codes
|
||||
if response.status_code >= 300:
|
||||
logger.error('Upload failed with status {}, {}'.format(response.status_code, response.text))
|
||||
return False
|
||||
|
||||
return True
|
||||
if response.status_code < 300:
|
||||
return True
|
||||
else:
|
||||
logger.error(f'OIDC authentication failed with status {response.status_code}, {response.text}')
|
||||
return False
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"OIDC authentication failed: {e}")
|
||||
return False
|
||||
|
||||
@@ -1,22 +1,25 @@
|
||||
import os
|
||||
|
||||
from dispatcherd.config import setup as dispatcher_setup
|
||||
|
||||
from django.apps import AppConfig
|
||||
from django.db import connection
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from awx.main.utils.common import bypass_in_test, load_all_entry_points_for
|
||||
from awx.main.utils.migration import is_database_synchronized
|
||||
from awx.main.utils.named_url_graph import _customize_graph, generate_graph
|
||||
from awx.conf import register, fields
|
||||
from django.core.management.base import CommandError
|
||||
from django.db.models.signals import pre_migrate
|
||||
|
||||
from awx_plugins.interfaces._temporary_private_licensing_api import detect_server_product_name
|
||||
from awx.main.utils.named_url_graph import _customize_graph, generate_graph
|
||||
from awx.main.utils.db import db_requirement_violations
|
||||
from awx.conf import register, fields
|
||||
|
||||
|
||||
class MainConfig(AppConfig):
|
||||
name = 'awx.main'
|
||||
verbose_name = _('Main')
|
||||
|
||||
def check_db_requirement(self, *args, **kwargs):
|
||||
violations = db_requirement_violations()
|
||||
if violations:
|
||||
raise CommandError(violations)
|
||||
|
||||
def load_named_url_feature(self):
|
||||
models = [m for m in self.get_models() if hasattr(m, 'get_absolute_url')]
|
||||
generate_graph(models)
|
||||
@@ -43,42 +46,6 @@ class MainConfig(AppConfig):
|
||||
category_slug='named-url',
|
||||
)
|
||||
|
||||
def _load_credential_types_feature(self):
|
||||
"""
|
||||
Create CredentialType records for any discovered credentials.
|
||||
|
||||
Note that Django docs advise _against_ interacting with the database using
|
||||
the ORM models in the ready() path. Specifically, during testing.
|
||||
However, we explicitly use the @bypass_in_test decorator to avoid calling this
|
||||
method during testing.
|
||||
|
||||
Django also advises against running pattern because it runs everywhere i.e.
|
||||
every management command. We use an advisory lock to ensure correctness and
|
||||
we will deal performance if it becomes an issue.
|
||||
"""
|
||||
from awx.main.models.credential import CredentialType
|
||||
|
||||
if is_database_synchronized():
|
||||
CredentialType.setup_tower_managed_defaults(app_config=self)
|
||||
|
||||
@bypass_in_test
|
||||
def load_credential_types_feature(self):
|
||||
from awx.main.models.credential import load_credentials
|
||||
|
||||
load_credentials()
|
||||
return self._load_credential_types_feature()
|
||||
|
||||
def load_inventory_plugins(self):
|
||||
from awx.main.models.inventory import InventorySourceOptions
|
||||
|
||||
is_awx = detect_server_product_name() == 'AWX'
|
||||
extra_entry_point_groups = () if is_awx else ('inventory.supported',)
|
||||
entry_points = load_all_entry_points_for(['inventory', *extra_entry_point_groups])
|
||||
|
||||
for entry_point_name, entry_point in entry_points.items():
|
||||
cls = entry_point.load()
|
||||
InventorySourceOptions.injectors[entry_point_name] = cls
|
||||
|
||||
def configure_dispatcherd(self):
|
||||
"""This implements the default configuration for dispatcherd
|
||||
|
||||
@@ -100,13 +67,5 @@ class MainConfig(AppConfig):
|
||||
super().ready()
|
||||
|
||||
self.configure_dispatcherd()
|
||||
|
||||
"""
|
||||
Credential loading triggers database operations. There are cases we want to call
|
||||
awx-manage collectstatic without a database. All management commands invoke the ready() code
|
||||
path. Using settings.AWX_SKIP_CREDENTIAL_TYPES_DISCOVER _could_ invoke a database operation.
|
||||
"""
|
||||
if not os.environ.get('AWX_SKIP_CREDENTIAL_TYPES_DISCOVER', None):
|
||||
self.load_credential_types_feature()
|
||||
self.load_named_url_feature()
|
||||
self.load_inventory_plugins()
|
||||
pre_migrate.connect(self.check_db_requirement, sender=self)
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
import functools
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.cache.backends.base import DEFAULT_TIMEOUT
|
||||
from django.core.cache.backends.redis import RedisCache
|
||||
|
||||
from redis.exceptions import ConnectionError, ResponseError, TimeoutError
|
||||
import socket
|
||||
|
||||
# This list comes from what django-redis ignores and the behavior we are trying
|
||||
# to retain while dropping the dependency on django-redis.
|
||||
IGNORED_EXCEPTIONS = (TimeoutError, ResponseError, ConnectionError, socket.timeout)
|
||||
|
||||
CONNECTION_INTERRUPTED_SENTINEL = object()
|
||||
|
||||
|
||||
def optionally_ignore_exceptions(func=None, return_value=None):
|
||||
if func is None:
|
||||
return functools.partial(optionally_ignore_exceptions, return_value=return_value)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except IGNORED_EXCEPTIONS as e:
|
||||
if settings.DJANGO_REDIS_IGNORE_EXCEPTIONS:
|
||||
return return_value
|
||||
raise e.__cause__ or e
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class AWXRedisCache(RedisCache):
|
||||
"""
|
||||
We just want to wrap the upstream RedisCache class so that we can ignore
|
||||
the exceptions that it raises when the cache is unavailable.
|
||||
"""
|
||||
|
||||
@optionally_ignore_exceptions
|
||||
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
|
||||
return super().add(key, value, timeout, version)
|
||||
|
||||
@optionally_ignore_exceptions(return_value=CONNECTION_INTERRUPTED_SENTINEL)
|
||||
def _get(self, key, default=None, version=None):
|
||||
return super().get(key, default, version)
|
||||
|
||||
def get(self, key, default=None, version=None):
|
||||
value = self._get(key, default, version)
|
||||
if value is CONNECTION_INTERRUPTED_SENTINEL:
|
||||
return default
|
||||
return value
|
||||
|
||||
@optionally_ignore_exceptions
|
||||
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
|
||||
return super().set(key, value, timeout, version)
|
||||
|
||||
@optionally_ignore_exceptions
|
||||
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
|
||||
return super().touch(key, timeout, version)
|
||||
|
||||
@optionally_ignore_exceptions
|
||||
def delete(self, key, version=None):
|
||||
return super().delete(key, version)
|
||||
|
||||
@optionally_ignore_exceptions
|
||||
def get_many(self, keys, version=None):
|
||||
return super().get_many(keys, version)
|
||||
|
||||
@optionally_ignore_exceptions
|
||||
def has_key(self, key, version=None):
|
||||
return super().has_key(key, version)
|
||||
|
||||
@optionally_ignore_exceptions
|
||||
def incr(self, key, delta=1, version=None):
|
||||
return super().incr(key, delta, version)
|
||||
|
||||
@optionally_ignore_exceptions
|
||||
def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
|
||||
return super().set_many(data, timeout, version)
|
||||
|
||||
@optionally_ignore_exceptions
|
||||
def delete_many(self, keys, version=None):
|
||||
return super().delete_many(keys, version)
|
||||
|
||||
@optionally_ignore_exceptions
|
||||
def clear(self):
|
||||
return super().clear()
|
||||
102
awx/main/conf.py
102
awx/main/conf.py
@@ -213,6 +213,40 @@ register(
|
||||
category_slug='system',
|
||||
)
|
||||
|
||||
register(
|
||||
'AWX_ANALYTICS_CANDLEPIN_CA',
|
||||
field_class=fields.CharField,
|
||||
default='/etc/rhsm/ca/redhat-uep.pem',
|
||||
allow_blank=True,
|
||||
label=_('Candlepin CA Certificate Path'),
|
||||
help_text=_('Path to the CA certificate file for verifying TLS connections to Candlepin. Leave blank to use system certificates.'),
|
||||
category=_('System'),
|
||||
category_slug='system',
|
||||
)
|
||||
|
||||
register(
|
||||
'AWX_ANALYTICS_CANDLEPIN_RENEWAL_THRESHOLD_DAYS',
|
||||
field_class=fields.IntegerField,
|
||||
default=90,
|
||||
min_value=1,
|
||||
label=_('Candlepin Certificate Renewal Threshold'),
|
||||
help_text=_('Number of days before certificate expiry to trigger automatic renewal of Candlepin identity certificates.'),
|
||||
category=_('System'),
|
||||
category_slug='system',
|
||||
unit=_('days'),
|
||||
)
|
||||
|
||||
register(
|
||||
'AWX_ANALYTICS_CANDLEPIN_PROXY_URL',
|
||||
field_class=fields.CharField,
|
||||
default='',
|
||||
allow_blank=True,
|
||||
label=_('Candlepin Proxy URL'),
|
||||
help_text=_('HTTP/HTTPS proxy URL for Candlepin API requests (e.g., http://proxy.example.com:8080). Leave blank for no proxy.'),
|
||||
category=_('System'),
|
||||
category_slug='system',
|
||||
)
|
||||
|
||||
register(
|
||||
'INSTALL_UUID',
|
||||
field_class=fields.CharField,
|
||||
@@ -291,6 +325,22 @@ register(
|
||||
category_slug='jobs',
|
||||
)
|
||||
|
||||
register(
|
||||
'INCLUDE_DEPRECATED_AWX_VAR_PREFIX',
|
||||
field_class=fields.BooleanField,
|
||||
default=True,
|
||||
label=_('Include Deprecated AWX Variable Prefix'),
|
||||
help_text=_(
|
||||
'When enabled (default), auto-generated job variables are emitted '
|
||||
'with both the tower_ prefix and the deprecated awx_ prefix for '
|
||||
'backward compatibility. Disable to emit only tower_ prefixed '
|
||||
'variables and eliminate duplicates. The awx_ prefix is deprecated '
|
||||
'and this setting will default to False in a future release.'
|
||||
),
|
||||
category=_('Jobs'),
|
||||
category_slug='jobs',
|
||||
)
|
||||
|
||||
register(
|
||||
'AWX_ISOLATION_BASE_PATH',
|
||||
field_class=fields.CharField,
|
||||
@@ -824,6 +874,58 @@ register(
|
||||
unit=_('seconds'),
|
||||
)
|
||||
|
||||
register(
|
||||
'CANDLEPIN_CONSUMER_UUID',
|
||||
field_class=fields.CharField,
|
||||
default='',
|
||||
allow_blank=True,
|
||||
encrypted=False,
|
||||
label=_('Candlepin Consumer UUID'),
|
||||
help_text=_('UUID of the registered Candlepin consumer for this AAP instance.'),
|
||||
category=_('System'),
|
||||
category_slug='system',
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
register(
|
||||
'CANDLEPIN_CERT_PEM',
|
||||
field_class=fields.CharField,
|
||||
default='',
|
||||
allow_blank=True,
|
||||
encrypted=True,
|
||||
label=_('Candlepin Identity Certificate'),
|
||||
help_text=_('PEM-encoded Candlepin identity certificate for mTLS authentication.'),
|
||||
category=_('System'),
|
||||
category_slug='system',
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
register(
|
||||
'CANDLEPIN_KEY_PEM',
|
||||
field_class=fields.CharField,
|
||||
default='',
|
||||
allow_blank=True,
|
||||
encrypted=True,
|
||||
label=_('Candlepin Identity Key'),
|
||||
help_text=_('PEM-encoded private key for Candlepin identity certificate.'),
|
||||
category=_('System'),
|
||||
category_slug='system',
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
register(
|
||||
'CANDLEPIN_SERIAL_NUMBER',
|
||||
field_class=fields.CharField,
|
||||
default='',
|
||||
allow_blank=True,
|
||||
encrypted=False,
|
||||
label=_('Candlepin Certificate Serial Number'),
|
||||
help_text=_('Serial number of the Candlepin identity certificate for tracking.'),
|
||||
category=_('System'),
|
||||
category_slug='system',
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
register(
|
||||
'IS_K8S',
|
||||
field_class=fields.BooleanField,
|
||||
|
||||
@@ -100,10 +100,6 @@ MAX_ISOLATED_PATH_COLON_DELIMITER = 2
|
||||
|
||||
SURVEY_TYPE_MAPPING = {'text': str, 'textarea': str, 'password': str, 'multiplechoice': str, 'multiselect': str, 'integer': int, 'float': (float, int)}
|
||||
|
||||
JOB_VARIABLE_PREFIXES = [
|
||||
'awx',
|
||||
'tower',
|
||||
]
|
||||
|
||||
# Note, the \u001b[... are ansi color codes. We don't currenly import any of the python modules which define the codes.
|
||||
# Importing a library just for this message seemed like overkill
|
||||
|
||||
@@ -25,12 +25,13 @@ def get_dispatcherd_config(for_service: bool = False, mock_publish: bool = False
|
||||
"version": 2,
|
||||
"service": {
|
||||
"pool_kwargs": {
|
||||
"min_workers": settings.JOB_EVENT_WORKERS,
|
||||
"min_workers": settings.DISPATCHER_MIN_WORKERS,
|
||||
"max_workers": max_workers,
|
||||
# This must be less than max_workers to make sense, which is usually 4
|
||||
# With reserve of 1, after a burst of tasks, load needs to down to 4-1=3
|
||||
# before we return to min_workers
|
||||
"scaledown_reserve": 1,
|
||||
"worker_max_lifetime_seconds": settings.WORKER_MAX_LIFETIME_SECONDS,
|
||||
},
|
||||
"main_kwargs": {"node_id": settings.CLUSTER_HOST_ID},
|
||||
"process_manager_cls": "ForkServerManager",
|
||||
|
||||
330
awx/main/management/commands/candlepin_cert.py
Normal file
330
awx/main/management/commands/candlepin_cert.py
Normal file
@@ -0,0 +1,330 @@
|
||||
import sys
|
||||
|
||||
from argparse import RawDescriptionHelpFormatter
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
from awx.main.utils.candlepin.client import CandlepinClient
|
||||
from awx.main.utils.candlepin.lifecycle import (
|
||||
get_candlepin_ca,
|
||||
get_candlepin_url,
|
||||
get_proxy_url,
|
||||
get_renewal_days,
|
||||
needs_renewal,
|
||||
parse_cert,
|
||||
)
|
||||
from awx.main.utils.candlepin import (
|
||||
_fetch_candlepin_cert_from_db,
|
||||
_save_candlepin_cert_to_db,
|
||||
_save_candlepin_registration_to_db,
|
||||
resolve_registration_credentials,
|
||||
)
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
"""
|
||||
Manage Candlepin consumer registration and certificate lifecycle.
|
||||
|
||||
Subcommands:
|
||||
register Register this AAP instance as a Candlepin consumer and obtain an
|
||||
identity certificate for mTLS analytics uploads.
|
||||
renew Perform a manual check-in and, if needed, renew the stored identity
|
||||
certificate.
|
||||
"""
|
||||
|
||||
help = 'Manage Candlepin consumer registration and certificate lifecycle'
|
||||
|
||||
def create_parser(self, prog_name, subcommand, **kwargs):
|
||||
return super().create_parser(
|
||||
prog_name,
|
||||
subcommand,
|
||||
formatter_class=RawDescriptionHelpFormatter,
|
||||
epilog='\n'.join(
|
||||
[
|
||||
'SUBCOMMANDS',
|
||||
'',
|
||||
' register Register this instance as a Candlepin consumer.',
|
||||
' Credentials are read from AWX database by default',
|
||||
' (REDHAT_USERNAME, REDHAT_PASSWORD). The organization is',
|
||||
' discovered automatically from the Candlepin account.',
|
||||
' Pass --username / --password-stdin / --org to override.',
|
||||
' Example: echo "password" | awx-manage candlepin_cert register --username user --password-stdin',
|
||||
'',
|
||||
' renew Perform a manual check-in and proactive cert renewal.',
|
||||
' Reads the stored cert/key/UUID from database.',
|
||||
' Use --force to renew even if the cert is not near expiry.',
|
||||
'',
|
||||
'CONFIGURATION',
|
||||
'',
|
||||
' Settings can be configured via Django settings (awx/settings/defaults.py):',
|
||||
'',
|
||||
' AWX_ANALYTICS_CANDLEPIN_URL Candlepin base URL',
|
||||
' (default: https://subscription.example.com/candlepin)',
|
||||
' AWX_ANALYTICS_CANDLEPIN_CA Path to Candlepin CA cert for TLS verification',
|
||||
' AWX_ANALYTICS_CANDLEPIN_RENEWAL_THRESHOLD_DAYS Days before expiry to trigger renewal (default: 90)',
|
||||
' AWX_ANALYTICS_CANDLEPIN_PROXY_URL HTTP/HTTPS proxy for Candlepin API calls',
|
||||
]
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def add_arguments(self, parser):
|
||||
subparsers = parser.add_subparsers(dest='subcommand', metavar='subcommand')
|
||||
subparsers.required = True
|
||||
|
||||
# --- register ---
|
||||
reg = subparsers.add_parser(
|
||||
'register',
|
||||
help='Register this instance as a Candlepin consumer',
|
||||
formatter_class=RawDescriptionHelpFormatter,
|
||||
)
|
||||
reg.add_argument('--username', help='Red Hat subscription username (overrides REDHAT_USERNAME from database)')
|
||||
reg.add_argument(
|
||||
'--password-stdin', dest='password_stdin', action='store_true', help='Read password from stdin (overrides REDHAT_PASSWORD from database)'
|
||||
)
|
||||
reg.add_argument('--org', help='Candlepin owner/org key (overrides auto-discovered organization)')
|
||||
reg.add_argument('--candlepin-url', dest='candlepin_url', help='Candlepin base URL (overrides AWX_ANALYTICS_CANDLEPIN_URL setting)')
|
||||
reg.add_argument(
|
||||
'--candlepin-ca', dest='candlepin_ca', help='Path to Candlepin CA cert for TLS verification (overrides AWX_ANALYTICS_CANDLEPIN_CA setting)'
|
||||
)
|
||||
reg.add_argument('--proxy', help='HTTP/HTTPS proxy URL (overrides AWX_ANALYTICS_CANDLEPIN_PROXY_URL setting)')
|
||||
reg.add_argument('--no-verify-tls', dest='no_verify_tls', action='store_true', help='Disable TLS certificate verification for Candlepin API calls')
|
||||
reg.add_argument('--force', action='store_true', help='Re-register even if a certificate already exists in database')
|
||||
reg.add_argument('--dry-run', dest='dry_run', action='store_true', help='Perform registration but do not save the result to database')
|
||||
|
||||
# --- renew ---
|
||||
ren = subparsers.add_parser(
|
||||
'renew',
|
||||
help='Check in and renew the Candlepin identity certificate',
|
||||
formatter_class=RawDescriptionHelpFormatter,
|
||||
)
|
||||
ren.add_argument('--candlepin-url', dest='candlepin_url', help='Candlepin base URL (overrides AWX_ANALYTICS_CANDLEPIN_URL setting)')
|
||||
ren.add_argument(
|
||||
'--candlepin-ca', dest='candlepin_ca', help='Path to Candlepin CA cert for TLS verification (overrides AWX_ANALYTICS_CANDLEPIN_CA setting)'
|
||||
)
|
||||
ren.add_argument('--proxy', help='HTTP/HTTPS proxy URL (overrides AWX_ANALYTICS_CANDLEPIN_PROXY_URL setting)')
|
||||
ren.add_argument('--no-verify-tls', dest='no_verify_tls', action='store_true', help='Disable TLS certificate verification for Candlepin API calls')
|
||||
ren.add_argument('--force', action='store_true', help='Renew the certificate even if it is not near expiry')
|
||||
ren.add_argument('--dry-run', dest='dry_run', action='store_true', help='Perform check-in and renewal but do not save the result to database')
|
||||
|
||||
def handle(self, *args, **options):
|
||||
subcommand = options['subcommand']
|
||||
if subcommand == 'register':
|
||||
ok = self._handle_register(options)
|
||||
elif subcommand == 'renew':
|
||||
ok = self._handle_renew(options)
|
||||
else:
|
||||
self.stderr.write(f'Unknown subcommand: {subcommand}')
|
||||
sys.exit(1)
|
||||
|
||||
if not ok:
|
||||
sys.exit(1)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# register
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _resolve_and_validate_credentials(self, options):
|
||||
"""Merge CLI options with DB values and validate all required fields are present.
|
||||
|
||||
Returns ``(username, password, org, db_install_uuid)`` on success, or ``None``
|
||||
if any required field is missing (errors are written to ``self.stderr``).
|
||||
"""
|
||||
username_override = options.get('username')
|
||||
org_override = options.get('org')
|
||||
verify_tls = not options.get('no_verify_tls', False)
|
||||
|
||||
# Read password from stdin if --password-stdin is set
|
||||
if options.get('password_stdin'):
|
||||
password_override = sys.stdin.read().strip()
|
||||
if not password_override:
|
||||
self.stderr.write('--password-stdin specified but no password provided on stdin')
|
||||
return None
|
||||
else:
|
||||
password_override = None
|
||||
|
||||
# Use shared resolution and validation function
|
||||
username, password, org, install_uuid, errors = resolve_registration_credentials(
|
||||
username_override=username_override, password_override=password_override, org_override=org_override, verify_tls=verify_tls
|
||||
)
|
||||
|
||||
if errors:
|
||||
for error in errors:
|
||||
self.stderr.write(f'Missing required value: {error}')
|
||||
return None
|
||||
|
||||
return username, password, org, install_uuid
|
||||
|
||||
def _handle_register(self, options):
|
||||
dry_run = options['dry_run']
|
||||
force = options['force']
|
||||
|
||||
# Check whether a cert is already stored unless --force.
|
||||
existing_cert, existing_key, _ = _fetch_candlepin_cert_from_db()
|
||||
if existing_cert and existing_key and not force:
|
||||
self.stdout.write('A Candlepin identity certificate is already stored in database. Use --force to re-register and replace it.')
|
||||
return True
|
||||
|
||||
# Resolve credentials: CLI flags take precedence over database.
|
||||
resolved = self._resolve_and_validate_credentials(options)
|
||||
if resolved is None:
|
||||
return False
|
||||
username, password, org, db_install_uuid = resolved
|
||||
|
||||
candlepin_url = options.get('candlepin_url') or get_candlepin_url()
|
||||
candlepin_ca = options.get('candlepin_ca') or get_candlepin_ca()
|
||||
proxy = options.get('proxy') or get_proxy_url()
|
||||
verify_tls = not options.get('no_verify_tls', False)
|
||||
|
||||
# If dry-run, display what would happen and exit early before any Candlepin operations
|
||||
if dry_run:
|
||||
self.stdout.write('[dry-run] Would register with Candlepin:')
|
||||
self.stdout.write(f' URL : {candlepin_url}')
|
||||
self.stdout.write(f' Organization : {org}')
|
||||
self.stdout.write(f' Username : {username}')
|
||||
self.stdout.write(f' Install UUID : {db_install_uuid}')
|
||||
if candlepin_ca:
|
||||
self.stdout.write(f' CA cert : {candlepin_ca}')
|
||||
if proxy:
|
||||
self.stdout.write(f' Proxy : {proxy}')
|
||||
self.stdout.write(f' Verify TLS : {verify_tls}')
|
||||
self.stdout.write('[dry-run] No Candlepin operations performed.')
|
||||
return True
|
||||
|
||||
client = CandlepinClient(base_url=candlepin_url, candlepin_ca=candlepin_ca, proxy=proxy, verify_tls=verify_tls)
|
||||
|
||||
self.stdout.write(f'Registering with Candlepin at {candlepin_url} (org={org}) ...')
|
||||
try:
|
||||
cert_pem, key_pem, consumer_uuid = client.register_consumer(username, password, org, install_uuid=db_install_uuid)
|
||||
except Exception as e:
|
||||
self.stderr.write(f'Registration failed: {e}')
|
||||
return False
|
||||
|
||||
self.stdout.write('Registered successfully.')
|
||||
self.stdout.write(f' Consumer UUID : {consumer_uuid}')
|
||||
|
||||
# Save to database
|
||||
if _save_candlepin_registration_to_db(cert_pem, key_pem, consumer_uuid):
|
||||
self.stdout.write('Certificate, key, and consumer UUID saved to database.')
|
||||
else:
|
||||
self.stderr.write('Failed to save registration to database.')
|
||||
return False
|
||||
|
||||
# Best-effort certificate metadata display
|
||||
try:
|
||||
info = parse_cert(cert_pem)
|
||||
self.stdout.write(f' Cert serial : {info["serial"]}')
|
||||
self.stdout.write(f' Cert CN : {info["cn"]}')
|
||||
self.stdout.write(f' Valid until : {info["not_after"]} ({info["days_remaining"]} days remaining)')
|
||||
except ValueError as e:
|
||||
self.stdout.write(f'Certificate metadata unavailable: {e}')
|
||||
|
||||
return True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# renew
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handle_renew(self, options):
|
||||
dry_run = options['dry_run']
|
||||
force = options['force']
|
||||
|
||||
cert_pem, key_pem, consumer_uuid = _fetch_candlepin_cert_from_db()
|
||||
|
||||
if not cert_pem or not key_pem:
|
||||
self.stderr.write('No Candlepin identity certificate found in database. Run the register subcommand first.')
|
||||
return False
|
||||
|
||||
if not consumer_uuid:
|
||||
self.stderr.write('CANDLEPIN_CONSUMER_UUID is not set. Run the register subcommand first.')
|
||||
return False
|
||||
|
||||
try:
|
||||
info = parse_cert(cert_pem)
|
||||
self.stdout.write('Current certificate:')
|
||||
self.stdout.write(f' Serial : {info["serial"]}')
|
||||
self.stdout.write(f' CN : {info["cn"]}')
|
||||
self.stdout.write(f' Valid until : {info["not_after"]} ({info["days_remaining"]} days remaining)')
|
||||
except ValueError as e:
|
||||
self.stdout.write('Current certificate:')
|
||||
self.stdout.write(f' Certificate metadata unavailable: {e}')
|
||||
info = None
|
||||
|
||||
candlepin_url = options.get('candlepin_url') or get_candlepin_url()
|
||||
candlepin_ca = options.get('candlepin_ca') or get_candlepin_ca()
|
||||
proxy = options.get('proxy') or get_proxy_url()
|
||||
verify_tls = not options.get('no_verify_tls', False)
|
||||
renewal_days = get_renewal_days()
|
||||
|
||||
# Check if renewal is needed (without force, just check cert expiry locally)
|
||||
renewal_needed = force or needs_renewal(cert_pem, renewal_days)
|
||||
|
||||
# If dry-run, display what would happen and exit early before any Candlepin operations
|
||||
if dry_run:
|
||||
self.stdout.write('[dry-run] Would perform the following operations:')
|
||||
self.stdout.write(f' URL : {candlepin_url}')
|
||||
self.stdout.write(f' Consumer UUID : {consumer_uuid}')
|
||||
if candlepin_ca:
|
||||
self.stdout.write(f' CA cert : {candlepin_ca}')
|
||||
if proxy:
|
||||
self.stdout.write(f' Proxy : {proxy}')
|
||||
self.stdout.write(f' Verify TLS : {verify_tls}')
|
||||
self.stdout.write(' 1. Check in with Candlepin')
|
||||
if renewal_needed:
|
||||
reason = 'forced via --force' if force else f'expiry within {renewal_days} days'
|
||||
self.stdout.write(f' 2. Renew certificate ({reason})')
|
||||
else:
|
||||
if info:
|
||||
self.stdout.write(f' 2. No renewal needed ({info["days_remaining"]} days remaining, threshold: {renewal_days} days)')
|
||||
else:
|
||||
self.stdout.write(f' 2. No renewal needed (threshold: {renewal_days} days)')
|
||||
self.stdout.write('[dry-run] No Candlepin operations performed.')
|
||||
return True
|
||||
|
||||
client = CandlepinClient(base_url=candlepin_url, candlepin_ca=candlepin_ca, proxy=proxy, verify_tls=verify_tls)
|
||||
|
||||
self.stdout.write(f'Checking in with Candlepin at {candlepin_url} (consumer={consumer_uuid}) ...')
|
||||
checkin_success = client.checkin(consumer_uuid, cert_pem, key_pem)
|
||||
|
||||
if not checkin_success:
|
||||
self.stderr.write('Check-in with Candlepin failed. Unable to verify certificate status.')
|
||||
self.stderr.write('Certificate renewal may still be needed. Use --force to renew anyway, or check logs for details.')
|
||||
return False
|
||||
|
||||
self.stdout.write('Check-in successful.')
|
||||
|
||||
if not renewal_needed:
|
||||
if info:
|
||||
self.stdout.write(f'Certificate has {info["days_remaining"]} days remaining (renewal threshold: {renewal_days} days). No renewal needed.')
|
||||
else:
|
||||
self.stdout.write(f'Certificate renewal threshold is {renewal_days} days. No renewal needed.')
|
||||
return True
|
||||
|
||||
reason = 'forced via --force' if force else f'expiry within {renewal_days} days'
|
||||
self.stdout.write(f'Renewing certificate ({reason}) ...')
|
||||
try:
|
||||
new_cert_pem, new_key_pem = client.regenerate_cert(consumer_uuid, cert_pem, key_pem)
|
||||
except Exception as e:
|
||||
self.stderr.write(f'Certificate renewal failed: {e}')
|
||||
return False
|
||||
|
||||
self.stdout.write('Certificate renewed successfully.')
|
||||
|
||||
# Save to database
|
||||
if _save_candlepin_cert_to_db(new_cert_pem, new_key_pem):
|
||||
self.stdout.write('Renewed certificate and key saved to database.')
|
||||
else:
|
||||
self.stderr.write('Failed to save renewed certificate to database.')
|
||||
return False
|
||||
|
||||
# Best-effort certificate metadata display
|
||||
try:
|
||||
new_info = parse_cert(new_cert_pem)
|
||||
if info:
|
||||
self.stdout.write(f' Old serial : {info["serial"]}')
|
||||
self.stdout.write(f' New serial : {new_info["serial"]}')
|
||||
self.stdout.write(f' Valid until : {new_info["not_after"]} ({new_info["days_remaining"]} days remaining)')
|
||||
except ValueError as e:
|
||||
self.stdout.write(f'Certificate metadata unavailable: {e}')
|
||||
|
||||
return True
|
||||
@@ -1,9 +1,11 @@
|
||||
# Copyright (c) 2015 Ansible, Inc.
|
||||
# All Rights Reserved
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
from django.db import connection
|
||||
|
||||
from awx.main.utils.db import db_requirement_violations
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
"""Checks connection to the database, and prints out connection info if not connected"""
|
||||
@@ -13,4 +15,8 @@ class Command(BaseCommand):
|
||||
cursor.execute("SELECT version()")
|
||||
version = str(cursor.fetchone()[0])
|
||||
|
||||
violations = db_requirement_violations()
|
||||
if violations:
|
||||
raise CommandError(violations)
|
||||
|
||||
return "Database Version: {}".format(version)
|
||||
|
||||
@@ -52,7 +52,11 @@ class Command(BaseCommand):
|
||||
|
||||
ssh_type = CredentialType.objects.filter(namespace='ssh').first()
|
||||
c, _ = Credential.objects.get_or_create(
|
||||
credential_type=ssh_type, name='Demo Credential', inputs={'username': getattr(superuser, 'username', 'null')}, created_by=superuser
|
||||
credential_type=ssh_type,
|
||||
name='Demo Credential',
|
||||
inputs={'username': getattr(superuser, 'username', 'null')},
|
||||
created_by=superuser,
|
||||
organization=o,
|
||||
)
|
||||
|
||||
if superuser:
|
||||
|
||||
@@ -5,6 +5,7 @@ import logging
|
||||
import uuid
|
||||
from django.db import models
|
||||
from django.conf import settings
|
||||
from django.db.models import OuterRef, Subquery
|
||||
from django.db.models.functions import Lower
|
||||
|
||||
from ansible_base.lib.utils.db import advisory_lock
|
||||
@@ -23,7 +24,65 @@ class DeferJobCreatedManager(models.Manager):
|
||||
return super(DeferJobCreatedManager, self).get_queryset().defer('job_created')
|
||||
|
||||
|
||||
class HostManager(models.Manager):
|
||||
class HostLatestSummaryQuerySet(models.QuerySet):
|
||||
"""Queryset that annotates and bulk-attaches the latest JobHostSummary
|
||||
at queryset evaluation time, similar to prefetch_related().
|
||||
|
||||
Why not use Django's Prefetch?
|
||||
Django's Prefetch with [:1] slicing fetches 1 record globally, not per-host
|
||||
(Django ticket #26780). Window-function workarounds require Django 4.2+ and
|
||||
are more complex. Prefetching all summaries then filtering in Python wastes
|
||||
memory for hosts with many job runs. The approach here — annotate the latest
|
||||
ID via Subquery, then in_bulk() only those IDs — is the same 2-query pattern
|
||||
prefetch_related uses internally, customized for "latest per group."
|
||||
|
||||
Not streaming-safe: relies on _result_cache existing after _fetch_all().
|
||||
"""
|
||||
|
||||
_awx_latest_summary_attached = False
|
||||
|
||||
def _clone(self):
|
||||
clone = super()._clone()
|
||||
clone._awx_latest_summary_attached = self._awx_latest_summary_attached
|
||||
return clone
|
||||
|
||||
def with_latest_summary_id(self):
|
||||
from awx.main.models.jobs import JobHostSummary
|
||||
|
||||
latest_summary = JobHostSummary.objects.filter(host_id=OuterRef('pk')).order_by('-id')
|
||||
return self.annotate(
|
||||
_latest_summary_id=Subquery(latest_summary.values('id')[:1]),
|
||||
)
|
||||
|
||||
def _fetch_all(self):
|
||||
super()._fetch_all()
|
||||
|
||||
if self._awx_latest_summary_attached or not self._result_cache:
|
||||
return
|
||||
|
||||
# Only bulk-attach if the queryset was annotated via with_latest_summary_id().
|
||||
# Without this guard, we'd set _latest_summary_cache=None on every host,
|
||||
# masking the per-object fallback query in Host.latest_summary.
|
||||
if not hasattr(self._result_cache[0], '_latest_summary_id'):
|
||||
return
|
||||
|
||||
from awx.main.models.jobs import JobHostSummary
|
||||
|
||||
latest_summary_ids = [host._latest_summary_id for host in self._result_cache if host._latest_summary_id is not None]
|
||||
|
||||
if latest_summary_ids:
|
||||
summaries_by_id = JobHostSummary.objects.select_related('job', 'job__job_template').in_bulk(latest_summary_ids)
|
||||
else:
|
||||
summaries_by_id = {}
|
||||
|
||||
for host in self._result_cache:
|
||||
latest_summary_id = getattr(host, '_latest_summary_id', None)
|
||||
host._latest_summary_cache = summaries_by_id.get(latest_summary_id)
|
||||
|
||||
self._awx_latest_summary_attached = True
|
||||
|
||||
|
||||
class HostManager(models.Manager.from_queryset(HostLatestSummaryQuerySet)):
|
||||
"""Custom manager class for Hosts model."""
|
||||
|
||||
def active_count(self):
|
||||
@@ -31,38 +90,46 @@ class HostManager(models.Manager):
|
||||
Construction of query involves:
|
||||
- remove any ordering specified in model's Meta
|
||||
- Exclude hosts sourced from another Tower
|
||||
- Exclude hosts in constructed inventories (these are shadow rows of source-inventory hosts)
|
||||
- Restrict the query to only return the name column
|
||||
- Only consider results that are unique
|
||||
- Return the count of this query
|
||||
"""
|
||||
return self.order_by().exclude(inventory_sources__source='controller').values(name_lower=Lower('name')).distinct().count()
|
||||
return (
|
||||
self.order_by()
|
||||
.exclude(inventory_sources__source='controller')
|
||||
.exclude(inventory__kind='constructed')
|
||||
.values(name_lower=Lower('name'))
|
||||
.distinct()
|
||||
.count()
|
||||
)
|
||||
|
||||
def org_active_count(self, org_id):
|
||||
"""Return count of active, unique hosts used by an organization.
|
||||
Construction of query involves:
|
||||
- remove any ordering specified in model's Meta
|
||||
- Exclude hosts sourced from another Tower
|
||||
- Exclude hosts in constructed inventories (these are shadow rows of source-inventory hosts)
|
||||
- Consider only hosts where the canonical inventory is owned by the organization
|
||||
- Restrict the query to only return the name column
|
||||
- Only consider results that are unique
|
||||
- Return the count of this query
|
||||
"""
|
||||
return self.order_by().exclude(inventory_sources__source='controller').filter(inventory__organization=org_id).values('name').distinct().count()
|
||||
return (
|
||||
self.order_by()
|
||||
.exclude(inventory_sources__source='controller')
|
||||
.exclude(inventory__kind='constructed')
|
||||
.filter(inventory__organization=org_id)
|
||||
.values('name')
|
||||
.distinct()
|
||||
.count()
|
||||
)
|
||||
|
||||
def get_queryset(self):
|
||||
"""When the parent instance of the host query set has a `kind=smart` and a `host_filter`
|
||||
set. Use the `host_filter` to generate the queryset for the hosts.
|
||||
"""
|
||||
qs = (
|
||||
super(HostManager, self)
|
||||
.get_queryset()
|
||||
.defer(
|
||||
'last_job__extra_vars',
|
||||
'last_job_host_summary__job__extra_vars',
|
||||
'last_job__artifacts',
|
||||
'last_job_host_summary__job__artifacts',
|
||||
)
|
||||
)
|
||||
qs = super().get_queryset().defer('ansible_facts')
|
||||
|
||||
if hasattr(self, 'instance') and hasattr(self.instance, 'host_filter') and hasattr(self.instance, 'kind'):
|
||||
if self.instance.kind == 'smart' and self.instance.host_filter is not None:
|
||||
|
||||
@@ -211,7 +211,7 @@ class AdHocCommand(UnifiedJob, JobNotificationMixin):
|
||||
return AdHocCommand.objects.create(**data)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
update_fields = kwargs.get('update_fields', [])
|
||||
update_fields = kwargs.get('update_fields') or []
|
||||
|
||||
def add_to_update_fields(name):
|
||||
if name not in update_fields:
|
||||
|
||||
@@ -177,7 +177,7 @@ class CreatedModifiedModel(BaseModel):
|
||||
)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
update_fields = list(kwargs.get('update_fields', []))
|
||||
update_fields = list(kwargs.get('update_fields') or [])
|
||||
# Manually perform auto_now_add and auto_now logic.
|
||||
if not self.pk and not self.created:
|
||||
self.created = now()
|
||||
@@ -207,7 +207,7 @@ class PasswordFieldsModel(BaseModel):
|
||||
new_instance = not bool(self.pk)
|
||||
# If update_fields has been specified, add our field names to it,
|
||||
# if it hasn't been specified, then we're just doing a normal save.
|
||||
update_fields = kwargs.get('update_fields', [])
|
||||
update_fields = kwargs.get('update_fields') or []
|
||||
# When first saving to the database, don't store any password field
|
||||
# values, but instead save them until after the instance is created.
|
||||
# Otherwise, store encrypted values to the database.
|
||||
@@ -322,7 +322,7 @@ class PrimordialModel(HasEditsMixin, CreatedModifiedModel):
|
||||
self._prior_values_store = {}
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
update_fields = kwargs.get('update_fields', [])
|
||||
update_fields = kwargs.get('update_fields') or []
|
||||
user = get_current_user()
|
||||
if user and not user.id:
|
||||
user = None
|
||||
|
||||
@@ -47,12 +47,9 @@ from awx.main.models.rbac import (
|
||||
)
|
||||
from awx.main.models import Team, Organization
|
||||
from awx.main.utils import encrypt_field
|
||||
from awx.main.utils.lazy_registry import LazyLoadDict
|
||||
from awx_plugins.interfaces._temporary_private_licensing_api import detect_server_product_name
|
||||
|
||||
# DAB
|
||||
from ansible_base.resource_registry.tasks.sync import get_resource_server_client
|
||||
from ansible_base.resource_registry.utils.settings import resource_server_defined
|
||||
|
||||
__all__ = ['Credential', 'CredentialType', 'CredentialInputSource', 'build_safe_env']
|
||||
|
||||
logger = logging.getLogger('awx.main.models.credential')
|
||||
@@ -80,46 +77,6 @@ def build_safe_env(env):
|
||||
return safe_env
|
||||
|
||||
|
||||
def check_resource_server_for_user_in_organization(user, organization, requesting_user):
|
||||
if not resource_server_defined():
|
||||
return False
|
||||
|
||||
if not requesting_user:
|
||||
return False
|
||||
|
||||
client = get_resource_server_client(settings.RESOURCE_SERVICE_PATH, jwt_user_id=str(requesting_user.resource.ansible_id), raise_if_bad_request=False)
|
||||
# need to get the organization object_id in resource server, by querying with ansible_id
|
||||
response = client._make_request(path=f'resources/?ansible_id={str(organization.resource.ansible_id)}', method='GET')
|
||||
response_json = response.json()
|
||||
if response.status_code != 200:
|
||||
logger.error(f'Failed to get organization object_id in resource server: {response_json.get("detail", "")}')
|
||||
return False
|
||||
|
||||
if response_json.get('count', 0) == 0:
|
||||
return False
|
||||
org_id_in_resource_server = response_json['results'][0]['object_id']
|
||||
|
||||
client.base_url = client.base_url.replace('/api/gateway/v1/service-index/', '/api/gateway/v1/')
|
||||
# find role assignments with:
|
||||
# - roles Organization Member or Organization Admin
|
||||
# - user ansible id
|
||||
# - organization object id
|
||||
|
||||
response = client._make_request(
|
||||
path=f'role_user_assignments/?role_definition__name__in=Organization Member,Organization Admin&user__resource__ansible_id={str(user.resource.ansible_id)}&object_id={org_id_in_resource_server}',
|
||||
method='GET',
|
||||
)
|
||||
response_json = response.json()
|
||||
if response.status_code != 200:
|
||||
logger.error(f'Failed to get role user assignments in resource server: {response_json.get("detail", "")}')
|
||||
return False
|
||||
|
||||
if response_json.get('count', 0) > 0:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class Credential(PasswordFieldsModel, CommonModelNameNotUnique, ResourceMixin):
|
||||
"""
|
||||
A credential contains information about how to talk to a remote resource
|
||||
@@ -396,16 +353,15 @@ class Credential(PasswordFieldsModel, CommonModelNameNotUnique, ResourceMixin):
|
||||
raise ValueError('{} is not a dynamic input field'.format(field_name))
|
||||
|
||||
def validate_role_assignment(self, actor, role_definition, **kwargs):
|
||||
requesting_user = kwargs.get('requesting_user', None)
|
||||
if requesting_user and requesting_user.is_superuser:
|
||||
return
|
||||
if self.organization:
|
||||
if isinstance(actor, User):
|
||||
if actor.is_superuser:
|
||||
return
|
||||
if Organization.access_qs(actor, 'member').filter(id=self.organization.id).exists():
|
||||
return
|
||||
|
||||
requesting_user = kwargs.get('requesting_user', None)
|
||||
if check_resource_server_for_user_in_organization(actor, self.organization, requesting_user):
|
||||
return
|
||||
if isinstance(actor, Team):
|
||||
if actor.organization == self.organization:
|
||||
return
|
||||
@@ -614,7 +570,7 @@ class CredentialTypeHelper:
|
||||
|
||||
|
||||
class ManagedCredentialType(SimpleNamespace):
|
||||
registry = {}
|
||||
registry = None # initialized as LazyLoadDict after load_credentials is defined
|
||||
|
||||
|
||||
class CredentialInputSource(PrimordialModel):
|
||||
@@ -706,6 +662,8 @@ def _is_oidc_namespace_disabled(ns):
|
||||
|
||||
|
||||
def load_credentials():
|
||||
ManagedCredentialType.registry.clear()
|
||||
|
||||
awx_entry_points = {ep.name: ep for ep in entry_points(group='awx_plugins.managed_credentials')}
|
||||
supported_entry_points = {ep.name: ep for ep in entry_points(group='awx_plugins.managed_credentials.supported')}
|
||||
plugin_entry_points = awx_entry_points if detect_server_product_name() == 'AWX' else {**awx_entry_points, **supported_entry_points}
|
||||
@@ -737,3 +695,8 @@ def load_credentials():
|
||||
|
||||
plugin = ep.load()
|
||||
CredentialType.load_plugin(ns, plugin)
|
||||
|
||||
|
||||
# load_credentials writes directly into this dict via registry[ns] = ...,
|
||||
# LazyLoadDict just ensures it runs once before the first read access
|
||||
ManagedCredentialType.registry = LazyLoadDict(load_credentials)
|
||||
|
||||
@@ -24,7 +24,6 @@ from awx.main.managers import DeferJobCreatedManager
|
||||
from awx.main.constants import MINIMAL_EVENTS
|
||||
from awx.main.models.base import CreatedModifiedModel
|
||||
from awx.main.utils import ignore_inventory_computed_fields, camelcase_to_underscore
|
||||
from awx.main.utils.db import bulk_update_sorted_by_id
|
||||
|
||||
analytics_logger = logging.getLogger('awx.analytics.job_events')
|
||||
|
||||
@@ -590,20 +589,8 @@ class JobEvent(BasePlaybookEvent):
|
||||
|
||||
JobHostSummary.objects.bulk_create(summaries.values())
|
||||
|
||||
# update the last_job_id and last_job_host_summary_id
|
||||
# in single queries
|
||||
host_mapping = dict((summary['host_id'], summary['id']) for summary in JobHostSummary.objects.filter(job_id=job.id).values('id', 'host_id'))
|
||||
updated_hosts = set()
|
||||
for h in all_hosts:
|
||||
# if the hostname *shows up* in the playbook_on_stats event
|
||||
if h.name in hostnames:
|
||||
h.last_job_id = job.id
|
||||
updated_hosts.add(h)
|
||||
if h.id in host_mapping:
|
||||
h.last_job_host_summary_id = host_mapping[h.id]
|
||||
updated_hosts.add(h)
|
||||
|
||||
bulk_update_sorted_by_id(Host, updated_hosts, ['last_job_id', 'last_job_host_summary_id'])
|
||||
# last_job and last_job_host_summary are now derived via
|
||||
# JobHostSummary.latest_for_host / latest_job_for_host
|
||||
|
||||
# Create/update Host Metrics
|
||||
self._update_host_metrics(updated_hosts_list)
|
||||
|
||||
@@ -58,8 +58,6 @@ class ExecutionEnvironment(CommonModel):
|
||||
return reverse('api:execution_environment_detail', kwargs={'pk': self.pk}, request=request)
|
||||
|
||||
def validate_role_assignment(self, actor, role_definition, **kwargs):
|
||||
from awx.main.models.credential import check_resource_server_for_user_in_organization
|
||||
|
||||
if self.managed:
|
||||
raise ValidationError({'object_id': _('Can not assign object roles to managed Execution Environments')})
|
||||
if self.organization_id is None:
|
||||
@@ -69,8 +67,4 @@ class ExecutionEnvironment(CommonModel):
|
||||
if actor.has_obj_perm(self.organization, 'view'):
|
||||
return
|
||||
|
||||
requesting_user = kwargs.get('requesting_user', None)
|
||||
if check_resource_server_for_user_in_organization(actor, self.organization, requesting_user):
|
||||
return
|
||||
|
||||
raise ValidationError({'user': _('User must have view permission to Execution Environment organization')})
|
||||
|
||||
@@ -18,7 +18,7 @@ from django.db import transaction
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.urls import resolve
|
||||
from django.utils.timezone import now
|
||||
from django.db.models import Q
|
||||
from django.db.models import Q, Subquery, OuterRef
|
||||
|
||||
# REST Framework
|
||||
from rest_framework.exceptions import ParseError
|
||||
@@ -27,7 +27,10 @@ from ansible_base.lib.utils.models import prevent_search
|
||||
|
||||
# AWX
|
||||
from awx.api.versioning import reverse
|
||||
from awx.main.utils.common import load_all_entry_points_for
|
||||
from awx.main.utils.lazy_registry import LazyLoadDict
|
||||
from awx.main.utils.plugins import discover_available_cloud_provider_plugin_names, compute_cloud_inventory_sources
|
||||
from awx_plugins.interfaces._temporary_private_licensing_api import detect_server_product_name
|
||||
from awx.main.consumers import emit_channel_notification
|
||||
from awx.main.fields import (
|
||||
ImplicitRoleField,
|
||||
@@ -386,7 +389,10 @@ class Inventory(CommonModelNameNotUnique, ResourceMixin, RelatedJobsMixin, OpaQu
|
||||
logger.debug("Going to update inventory computed fields, pk={0}".format(self.pk))
|
||||
start_time = time.time()
|
||||
active_hosts = self.hosts
|
||||
failed_hosts = active_hosts.filter(last_job_host_summary__failed=True)
|
||||
from awx.main.models.jobs import JobHostSummary # circular import: inventory.py loads before jobs.py
|
||||
|
||||
latest_summary_failed = Subquery(JobHostSummary.objects.filter(host_id=OuterRef('pk')).order_by('-id').values('failed')[:1])
|
||||
failed_hosts = active_hosts.annotate(_latest_failed=latest_summary_failed).filter(_latest_failed=True)
|
||||
active_groups = self.groups
|
||||
if self.kind == 'smart':
|
||||
active_groups = active_groups.none()
|
||||
@@ -582,6 +588,23 @@ class Host(CommonModelNameNotUnique, RelatedJobsMixin):
|
||||
|
||||
objects = HostManager()
|
||||
|
||||
@property
|
||||
def latest_summary(self):
|
||||
if hasattr(self, '_latest_summary_cache'):
|
||||
return self._latest_summary_cache
|
||||
from awx.main.models.jobs import JobHostSummary
|
||||
|
||||
summary = JobHostSummary.objects.filter(host_id=self.pk).order_by('-id').select_related('job', 'job__job_template').first()
|
||||
self._latest_summary_cache = summary
|
||||
return summary
|
||||
|
||||
@property
|
||||
def latest_job(self):
|
||||
summary = self.latest_summary
|
||||
if summary is None:
|
||||
return None
|
||||
return summary.job
|
||||
|
||||
def get_absolute_url(self, request=None):
|
||||
return reverse('api:host_detail', kwargs={'pk': self.pk}, request=request)
|
||||
|
||||
@@ -906,12 +929,22 @@ class HostMetricSummaryMonthly(models.Model):
|
||||
indirectly_managed_hosts = models.IntegerField(default=0, help_text=("Manually entered number indirectly managed hosts for a certain month"))
|
||||
|
||||
|
||||
def _load_inventory_plugins():
|
||||
is_awx = detect_server_product_name() == 'AWX'
|
||||
extra_entry_point_groups = () if is_awx else ('inventory.supported',)
|
||||
all_entry_points = load_all_entry_points_for(['inventory', *extra_entry_point_groups])
|
||||
|
||||
for entry_point_name, entry_point in all_entry_points.items():
|
||||
cls = entry_point.load()
|
||||
InventorySourceOptions.injectors[entry_point_name] = cls
|
||||
|
||||
|
||||
class InventorySourceOptions(BaseModel):
|
||||
"""
|
||||
Common fields for InventorySource and InventoryUpdate.
|
||||
"""
|
||||
|
||||
injectors = dict()
|
||||
injectors = LazyLoadDict(_load_inventory_plugins)
|
||||
|
||||
# From the options of the Django management base command
|
||||
INVENTORY_UPDATE_VERBOSITY_CHOICES = [
|
||||
@@ -1129,7 +1162,7 @@ class InventorySource(UnifiedJobTemplate, InventorySourceOptions, CustomVirtualE
|
||||
|
||||
# If update_fields has been specified, add our field names to it,
|
||||
# if it hasn't been specified, then we're just doing a normal save.
|
||||
update_fields = kwargs.get('update_fields', [])
|
||||
update_fields = kwargs.get('update_fields') or []
|
||||
is_new_instance = not bool(self.pk)
|
||||
|
||||
# Set name automatically. Include PK (or placeholder) to make sure the names are always unique.
|
||||
|
||||
@@ -52,7 +52,7 @@ from awx.main.models.mixins import (
|
||||
WebhookTemplateMixin,
|
||||
OpaQueryPathMixin,
|
||||
)
|
||||
from awx.main.constants import JOB_VARIABLE_PREFIXES
|
||||
from awx.main.utils.common import get_job_variable_prefixes
|
||||
|
||||
logger = logging.getLogger('awx.main.models.jobs')
|
||||
|
||||
@@ -347,7 +347,7 @@ class JobTemplate(
|
||||
return actual_slice_count
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
update_fields = kwargs.get('update_fields', [])
|
||||
update_fields = kwargs.get('update_fields') or []
|
||||
# if project is deleted for some reason, then keep the old organization
|
||||
# to retain ownership for organization admins
|
||||
if self.project and self.project.organization_id != self.organization_id:
|
||||
@@ -817,19 +817,20 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana
|
||||
|
||||
def awx_meta_vars(self):
|
||||
r = super(Job, self).awx_meta_vars()
|
||||
prefixes = get_job_variable_prefixes()
|
||||
if self.project:
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in prefixes:
|
||||
r['{}_project_revision'.format(name)] = self.project.scm_revision
|
||||
r['{}_project_scm_branch'.format(name)] = self.project.scm_branch
|
||||
if self.scm_branch:
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in prefixes:
|
||||
r['{}_job_scm_branch'.format(name)] = self.scm_branch
|
||||
if self.job_template:
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in prefixes:
|
||||
r['{}_job_template_id'.format(name)] = self.job_template.pk
|
||||
r['{}_job_template_name'.format(name)] = self.job_template.name
|
||||
if self.execution_node:
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in prefixes:
|
||||
r['{}_execution_node'.format(name)] = self.execution_node
|
||||
return r
|
||||
|
||||
@@ -1140,6 +1141,22 @@ class JobHostSummary(CreatedModifiedModel):
|
||||
self.skipped,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def latest_for_host(cls, host_id):
|
||||
"""Return the most recent JobHostSummary for a given host, or None."""
|
||||
return cls.objects.filter(host_id=host_id).order_by('-id').first()
|
||||
|
||||
@classmethod
|
||||
def latest_job_for_host(cls, host_id):
|
||||
"""Return the Job from the most recent JobHostSummary for a host, or None."""
|
||||
summary = cls.latest_for_host(host_id)
|
||||
if summary:
|
||||
try:
|
||||
return summary.job
|
||||
except cls.job.field.related_model.DoesNotExist:
|
||||
return None
|
||||
return None
|
||||
|
||||
def get_absolute_url(self, request=None):
|
||||
return reverse('api:job_host_summary_detail', kwargs={'pk': self.pk}, request=request)
|
||||
|
||||
@@ -1148,7 +1165,7 @@ class JobHostSummary(CreatedModifiedModel):
|
||||
# if it hasn't been specified, then we're just doing a normal save.
|
||||
if self.host is not None:
|
||||
self.host_name = self.host.name
|
||||
update_fields = kwargs.get('update_fields', [])
|
||||
update_fields = kwargs.get('update_fields') or []
|
||||
self.failed = bool(self.dark or self.failures)
|
||||
update_fields.append('failed')
|
||||
super(JobHostSummary, self).save(*args, **kwargs)
|
||||
|
||||
@@ -99,7 +99,7 @@ class NotificationTemplate(CommonModelNameNotUnique):
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
new_instance = not bool(self.pk)
|
||||
update_fields = kwargs.get('update_fields', [])
|
||||
update_fields = kwargs.get('update_fields') or []
|
||||
|
||||
# preserve existing notification messages if not overwritten by new messages
|
||||
if not new_instance:
|
||||
|
||||
@@ -367,7 +367,7 @@ class Project(UnifiedJobTemplate, ProjectOptions, ResourceMixin, CustomVirtualEn
|
||||
pre_save_vals = getattr(self, '_prior_values_store', {})
|
||||
# If update_fields has been specified, add our field names to it,
|
||||
# if it hasn't been specified, then we're just doing a normal save.
|
||||
update_fields = kwargs.get('update_fields', [])
|
||||
update_fields = kwargs.get('update_fields') or []
|
||||
self._skip_update = bool(kwargs.pop('skip_update', False))
|
||||
# Create auto-generated local path if project uses SCM.
|
||||
if self.pk and self.scm_type and not self.local_path.startswith('_'):
|
||||
|
||||
@@ -613,7 +613,7 @@ def get_role_from_object_role(object_role):
|
||||
model_name, role_name = rd.name.split()
|
||||
role_name = role_name.lower()
|
||||
role_name += '_role'
|
||||
return getattr(object_role.content_object, role_name)
|
||||
return getattr(object_role.content_object, role_name, None)
|
||||
|
||||
|
||||
def give_or_remove_permission(role, actor, giving=True, rd=None):
|
||||
@@ -649,6 +649,8 @@ def give_creator_permissions(user, obj):
|
||||
if assignment:
|
||||
with disable_rbac_sync():
|
||||
old_role = get_role_from_object_role(assignment.object_role)
|
||||
if old_role is None:
|
||||
return
|
||||
old_role.members.add(user)
|
||||
|
||||
|
||||
|
||||
@@ -72,10 +72,10 @@ def _fast_forward_rrule(rrule, ref_dt=None):
|
||||
if ref_dt is None:
|
||||
ref_dt = now()
|
||||
|
||||
ref_dt = ref_dt.astimezone(datetime.timezone.utc)
|
||||
dtstart_tz = rrule._dtstart.tzinfo
|
||||
ref_dt = ref_dt.astimezone(dtstart_tz)
|
||||
|
||||
rrule_dtstart_utc = rrule._dtstart.astimezone(datetime.timezone.utc)
|
||||
if rrule_dtstart_utc > ref_dt:
|
||||
if rrule._dtstart > ref_dt:
|
||||
return rrule
|
||||
|
||||
interval = rrule._interval if rrule._interval else 1
|
||||
@@ -84,20 +84,14 @@ def _fast_forward_rrule(rrule, ref_dt=None):
|
||||
elif rrule._freq == dateutil.rrule.MINUTELY:
|
||||
interval *= 60
|
||||
|
||||
# if after converting to seconds the interval is still a fraction,
|
||||
# just return original rrule
|
||||
if isinstance(interval, float) and not interval.is_integer():
|
||||
return rrule
|
||||
|
||||
seconds_since_dtstart = (ref_dt - rrule_dtstart_utc).total_seconds()
|
||||
seconds_since_dtstart = (ref_dt - rrule._dtstart).total_seconds()
|
||||
|
||||
# it is important to fast forward by a number that is divisible by
|
||||
# interval. For example, if interval is 7 hours, we fast forward by 7, 14, 21, etc. hours.
|
||||
# Otherwise, the occurrences after the fast forward might not match the ones before.
|
||||
# x // y is integer division, lopping off any remainder, so that we get the outcome we want.
|
||||
interval_aligned_offset = datetime.timedelta(seconds=(seconds_since_dtstart // interval) * interval)
|
||||
new_start = rrule_dtstart_utc + interval_aligned_offset
|
||||
new_rrule = rrule.replace(dtstart=new_start.astimezone(rrule._dtstart.tzinfo))
|
||||
new_start = rrule._dtstart + interval_aligned_offset
|
||||
new_rrule = rrule.replace(dtstart=new_start)
|
||||
return new_rrule
|
||||
|
||||
|
||||
|
||||
@@ -58,7 +58,8 @@ from awx.main.utils.common import (
|
||||
)
|
||||
from awx.main.utils.encryption import encrypt_dict, decrypt_field
|
||||
from awx.main.utils import polymorphic
|
||||
from awx.main.constants import ACTIVE_STATES, CAN_CANCEL, JOB_VARIABLE_PREFIXES
|
||||
from awx.main.constants import ACTIVE_STATES, CAN_CANCEL
|
||||
from awx.main.utils.common import get_job_variable_prefixes
|
||||
from awx.main.redact import UriCleaner, REPLACE_STR
|
||||
from awx.main.consumers import emit_channel_notification
|
||||
from awx.main.fields import AskForField, OrderedManyToManyField
|
||||
@@ -304,7 +305,7 @@ class UnifiedJobTemplate(PolymorphicModel, CommonModelNameNotUnique, ExecutionEn
|
||||
def save(self, *args, **kwargs):
|
||||
# If update_fields has been specified, add our field names to it,
|
||||
# if it hasn't been specified, then we're just doing a normal save.
|
||||
update_fields = kwargs.get('update_fields', [])
|
||||
update_fields = kwargs.get('update_fields') or []
|
||||
# Update status and last_updated fields.
|
||||
if not getattr(_inventory_updates, 'is_updating', False):
|
||||
updated_fields = self._set_status_and_last_job_run(save=False)
|
||||
@@ -876,7 +877,7 @@ class UnifiedJob(
|
||||
"""
|
||||
# If update_fields has been specified, add our field names to it,
|
||||
# if it hasn't been specified, then we're just doing a normal save.
|
||||
update_fields = kwargs.get('update_fields', [])
|
||||
update_fields = kwargs.get('update_fields') or []
|
||||
|
||||
# Get status before save...
|
||||
status_before = self.status or 'new'
|
||||
@@ -1568,7 +1569,8 @@ class UnifiedJob(
|
||||
by AWX, for purposes of client playbook hooks
|
||||
"""
|
||||
r = {}
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
prefixes = get_job_variable_prefixes()
|
||||
for name in prefixes:
|
||||
r['{}_job_id'.format(name)] = self.pk
|
||||
r['{}_job_launch_type'.format(name)] = self.launch_type
|
||||
|
||||
@@ -1577,7 +1579,7 @@ class UnifiedJob(
|
||||
wj = self.get_workflow_job()
|
||||
if wj:
|
||||
schedule = getattr_dne(wj, 'schedule')
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in prefixes:
|
||||
r['{}_workflow_job_id'.format(name)] = wj.pk
|
||||
r['{}_workflow_job_name'.format(name)] = wj.name
|
||||
r['{}_workflow_job_launch_type'.format(name)] = wj.launch_type
|
||||
@@ -1588,12 +1590,12 @@ class UnifiedJob(
|
||||
if not created_by:
|
||||
schedule = getattr_dne(self, 'schedule')
|
||||
if schedule:
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in prefixes:
|
||||
r['{}_schedule_id'.format(name)] = schedule.pk
|
||||
r['{}_schedule_name'.format(name)] = schedule.name
|
||||
|
||||
if created_by:
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in prefixes:
|
||||
r['{}_user_id'.format(name)] = created_by.pk
|
||||
r['{}_user_name'.format(name)] = created_by.username
|
||||
r['{}_user_email'.format(name)] = created_by.email
|
||||
@@ -1602,7 +1604,7 @@ class UnifiedJob(
|
||||
|
||||
inventory = getattr_dne(self, 'inventory')
|
||||
if inventory:
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in prefixes:
|
||||
r['{}_inventory_id'.format(name)] = inventory.pk
|
||||
r['{}_inventory_name'.format(name)] = inventory.name
|
||||
|
||||
|
||||
@@ -335,9 +335,7 @@ class WorkflowJobNode(WorkflowNodeBase):
|
||||
# or labels, because they do not propogate WFJT-->node at all
|
||||
|
||||
# Combine WFJT prompts with node here, WFJT at higher level
|
||||
# Empty string values on the workflow job (e.g. from IaC setting limit: "")
|
||||
# should not override a node's explicit non-empty prompt value
|
||||
node_prompts_data.update({k: v for k, v in wj_prompts_data.items() if v != ''})
|
||||
node_prompts_data.update(wj_prompts_data)
|
||||
accepted_fields, ignored_fields, errors = ujt_obj._accept_or_ignore_job_kwargs(**node_prompts_data)
|
||||
if errors:
|
||||
logger.info(
|
||||
@@ -347,7 +345,11 @@ class WorkflowJobNode(WorkflowNodeBase):
|
||||
)
|
||||
data.update(accepted_fields) # missing fields are handled in the scheduler
|
||||
# build ancestor artifacts, save them to node model for later
|
||||
aa_dict = {}
|
||||
# initialize from pre-seeded ancestor_artifacts (set on root nodes of
|
||||
# child workflows via seed_root_ancestor_artifacts to carry artifacts
|
||||
# from the parent workflow); exclude job_slice which is internal
|
||||
# metadata handled separately below
|
||||
aa_dict = {k: v for k, v in self.ancestor_artifacts.items() if k != 'job_slice'} if self.ancestor_artifacts else {}
|
||||
is_root_node = True
|
||||
for parent_node in self.get_parent_nodes():
|
||||
is_root_node = False
|
||||
@@ -368,11 +370,13 @@ class WorkflowJobNode(WorkflowNodeBase):
|
||||
data['survey_passwords'] = password_dict
|
||||
# process extra_vars
|
||||
extra_vars = data.get('extra_vars', {})
|
||||
if ujt_obj and isinstance(ujt_obj, (JobTemplate, WorkflowJobTemplate)):
|
||||
if ujt_obj and isinstance(ujt_obj, JobTemplate):
|
||||
if aa_dict:
|
||||
functional_aa_dict = copy(aa_dict)
|
||||
functional_aa_dict.pop('_ansible_no_log', None)
|
||||
extra_vars.update(functional_aa_dict)
|
||||
elif ujt_obj and isinstance(ujt_obj, WorkflowJobTemplate):
|
||||
pass # artifacts are applied via seed_root_ancestor_artifacts in the task manager
|
||||
|
||||
# Workflow Job extra_vars higher precedence than ancestor artifacts
|
||||
extra_vars.update(wj_special_vars)
|
||||
@@ -736,6 +740,18 @@ class WorkflowJob(UnifiedJob, WorkflowJobOptions, SurveyJobMixin, JobNotificatio
|
||||
wj = wj.get_workflow_job()
|
||||
return ancestors
|
||||
|
||||
def seed_root_ancestor_artifacts(self, artifacts):
|
||||
"""Apply parent workflow artifacts to root nodes so they propagate
|
||||
through the normal ancestor_artifacts channel instead of being
|
||||
baked into this workflow's extra_vars."""
|
||||
self.workflow_job_nodes.exclude(
|
||||
workflowjobnodes_success__isnull=False,
|
||||
).exclude(
|
||||
workflowjobnodes_failure__isnull=False,
|
||||
).exclude(
|
||||
workflowjobnodes_always__isnull=False,
|
||||
).update(ancestor_artifacts=artifacts)
|
||||
|
||||
def get_effective_artifacts(self, **kwargs):
|
||||
"""
|
||||
For downstream jobs of a workflow nested inside of a workflow,
|
||||
@@ -884,7 +900,7 @@ class WorkflowApproval(UnifiedJob, JobNotificationMixin):
|
||||
return 'workflow_approval_template'
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
update_fields = list(kwargs.get('update_fields', []))
|
||||
update_fields = list(kwargs.get('update_fields') or [])
|
||||
if self.timeout != 0 and ((not self.pk) or (not update_fields) or ('timeout' in update_fields)):
|
||||
if not self.created: # on creation, created will be set by parent class, so we fudge it here
|
||||
created = now()
|
||||
|
||||
@@ -241,6 +241,8 @@ class WorkflowManager(TaskBase):
|
||||
job = spawn_node.unified_job_template.create_unified_job(**kv)
|
||||
spawn_node.job = job
|
||||
spawn_node.save()
|
||||
if spawn_node.ancestor_artifacts and isinstance(spawn_node.unified_job_template, WorkflowJobTemplate):
|
||||
job.seed_root_ancestor_artifacts(spawn_node.ancestor_artifacts)
|
||||
logger.debug('Spawned %s in %s for node %s', job.log_format, workflow_job.log_format, spawn_node.pk)
|
||||
can_start = True
|
||||
if isinstance(spawn_node.unified_job_template, WorkflowJobTemplate):
|
||||
@@ -686,6 +688,17 @@ class TaskManager(TaskBase):
|
||||
logger.error(f'{j.execution_node} is not a registered instance; reaping {j.log_format}')
|
||||
reap_job(j, 'failed')
|
||||
|
||||
# Reset waiting jobs whose controller_node was deprovisioned (e.g. K8s pod replaced).
|
||||
# These jobs will never be picked up because no live node is listening for them.
|
||||
registered_control_nodes = Instance.objects.filter(node_type__in=('control', 'hybrid')).values_list('hostname', flat=True)
|
||||
orphaned_waiting = UnifiedJob.objects.filter(status='waiting').exclude(controller_node__in=registered_control_nodes)
|
||||
for j in orphaned_waiting:
|
||||
logger.warning(f'{j.controller_node} is not a registered instance; resetting {j.log_format} to pending')
|
||||
j.status = 'pending'
|
||||
j.controller_node = ''
|
||||
j.execution_node = ''
|
||||
j.save(update_fields=['status', 'controller_node', 'execution_node'])
|
||||
|
||||
def process_tasks(self):
|
||||
# maintain a list of jobs that went to an early failure state,
|
||||
# meaning the dispatcher never got these jobs,
|
||||
|
||||
@@ -36,7 +36,6 @@ from awx.main.models import (
|
||||
Inventory,
|
||||
InventorySource,
|
||||
Job,
|
||||
JobHostSummary,
|
||||
Organization,
|
||||
Project,
|
||||
Role,
|
||||
@@ -251,45 +250,9 @@ def migrate_children_from_deleted_group_to_parent_groups(sender, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
# Update host pointers to last_job and last_job_host_summary when a job is deleted
|
||||
|
||||
|
||||
def _update_host_last_jhs(host):
|
||||
jhs_qs = JobHostSummary.objects.filter(host__pk=host.pk)
|
||||
try:
|
||||
jhs = jhs_qs.order_by('-job__pk')[0]
|
||||
except IndexError:
|
||||
jhs = None
|
||||
update_fields = []
|
||||
try:
|
||||
last_job = jhs.job if jhs else None
|
||||
except Job.DoesNotExist:
|
||||
# The job (and its summaries) have already been/are currently being
|
||||
# deleted, so there's no need to update the host w/ a reference to it
|
||||
return
|
||||
if host.last_job != last_job:
|
||||
host.last_job = last_job
|
||||
update_fields.append('last_job')
|
||||
if host.last_job_host_summary != jhs:
|
||||
host.last_job_host_summary = jhs
|
||||
update_fields.append('last_job_host_summary')
|
||||
if update_fields:
|
||||
host.save(update_fields=update_fields)
|
||||
|
||||
|
||||
@receiver(pre_delete, sender=Job)
|
||||
def save_host_pks_before_job_delete(sender, **kwargs):
|
||||
instance = kwargs['instance']
|
||||
hosts_qs = Host.objects.filter(last_job__pk=instance.pk)
|
||||
instance._saved_hosts_pks = set(hosts_qs.values_list('pk', flat=True))
|
||||
|
||||
|
||||
@receiver(post_delete, sender=Job)
|
||||
def update_host_last_job_after_job_deleted(sender, **kwargs):
|
||||
instance = kwargs['instance']
|
||||
hosts_pks = getattr(instance, '_saved_hosts_pks', [])
|
||||
for host in Host.objects.filter(pk__in=hosts_pks):
|
||||
_update_host_last_jhs(host)
|
||||
# Host.last_job and Host.last_job_host_summary are now derived from
|
||||
# JobHostSummary.latest_for_host / latest_job_for_host.
|
||||
# No signal handlers needed to maintain these denormalized FKs.
|
||||
|
||||
|
||||
# Set via ActivityStreamRegistrar to record activity stream events
|
||||
|
||||
@@ -54,9 +54,6 @@ def try_load_query_file(artifact_dir) -> Tuple[bool, Optional[dict]]:
|
||||
returns the contents of ansible_data.json if present
|
||||
"""
|
||||
|
||||
if not flag_enabled("FEATURE_INDIRECT_NODE_COUNTING_ENABLED"):
|
||||
return False, None
|
||||
|
||||
queries_path = os.path.join(artifact_dir, COLLECTION_FILENAME)
|
||||
if not os.path.isfile(queries_path):
|
||||
logger.info(f"no query file found: {queries_path}")
|
||||
@@ -277,20 +274,6 @@ class RunnerCallback:
|
||||
def artifacts_handler(self, artifact_dir):
|
||||
success, query_file_contents = try_load_query_file(artifact_dir)
|
||||
if success:
|
||||
self.delay_update(event_queries_processed=False)
|
||||
collections_info = collect_queries(query_file_contents)
|
||||
for collection, data in collections_info.items():
|
||||
version = data['version']
|
||||
event_query = data['host_query']
|
||||
instance = EventQuery(fqcn=collection, collection_version=version, event_query=event_query)
|
||||
try:
|
||||
instance.validate_unique()
|
||||
instance.save()
|
||||
|
||||
logger.info(f"eventy query for collection {collection}, version {version} created")
|
||||
except ValidationError as e:
|
||||
logger.info(e)
|
||||
|
||||
if 'installed_collections' in query_file_contents:
|
||||
self.delay_update(installed_collections=query_file_contents['installed_collections'])
|
||||
else:
|
||||
@@ -301,6 +284,21 @@ class RunnerCallback:
|
||||
else:
|
||||
logger.warning(f'The file {COLLECTION_FILENAME} unexpectedly did not contain ansible_version')
|
||||
|
||||
if flag_enabled("FEATURE_INDIRECT_NODE_COUNTING_ENABLED"):
|
||||
self.delay_update(event_queries_processed=False)
|
||||
collections_info = collect_queries(query_file_contents)
|
||||
for collection, data in collections_info.items():
|
||||
version = data['version']
|
||||
event_query = data['host_query']
|
||||
instance = EventQuery(fqcn=collection, collection_version=version, event_query=event_query)
|
||||
try:
|
||||
instance.validate_unique()
|
||||
instance.save()
|
||||
|
||||
logger.info(f"event query for collection {collection}, version {version} created")
|
||||
except ValidationError as e:
|
||||
logger.info(e)
|
||||
|
||||
self.artifacts_processed = True
|
||||
|
||||
|
||||
|
||||
@@ -99,64 +99,99 @@ def finish_fact_cache(host_qs, artifacts_dir, job_id=None, inventory_id=None, jo
|
||||
try:
|
||||
with open(summary_path, 'r', encoding='utf-8') as f:
|
||||
summary = json.load(f)
|
||||
facts_write_time = os.path.getmtime(summary_path) # After successful read
|
||||
facts_write_time = os.path.getmtime(summary_path)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.error(f'Error reading summary file at {summary_path}: {e}')
|
||||
return
|
||||
|
||||
hosts_cached_map = summary.get('hosts_cached', {})
|
||||
host_names = list(hosts_cached_map.keys())
|
||||
hosts_cached = host_qs.filter(name__in=host_names).order_by('id').iterator()
|
||||
# Path where individual fact files were written
|
||||
fact_cache_dir = os.path.join(artifacts_dir, 'fact_cache')
|
||||
hosts_to_update = []
|
||||
|
||||
for host in hosts_cached:
|
||||
filepath = os.path.join(fact_cache_dir, host.name)
|
||||
if not os.path.realpath(filepath).startswith(fact_cache_dir):
|
||||
logger.error(f'Invalid path for facts file: {filepath}')
|
||||
continue
|
||||
# Phase 1: Scan files on disk to discover which hosts have updated or missing facts
|
||||
hosts_with_updates = set() # hostnames whose fact file was modified by Ansible
|
||||
hosts_to_clear = [] # hostnames where Ansible removed the fact file
|
||||
seen_in_dir = set() # hostnames we found as files on disk
|
||||
|
||||
if os.path.exists(filepath):
|
||||
# If the file changed since we wrote the last facts file, pre-playbook run...
|
||||
modified = os.path.getmtime(filepath)
|
||||
if not facts_write_time or modified >= facts_write_time:
|
||||
try:
|
||||
with codecs.open(filepath, 'r', encoding='utf-8') as f:
|
||||
ansible_facts = json.load(f)
|
||||
except ValueError:
|
||||
continue
|
||||
if os.path.isdir(fact_cache_dir):
|
||||
for filename in os.listdir(fact_cache_dir):
|
||||
if filename not in hosts_cached_map:
|
||||
continue # not an expected host for this job
|
||||
|
||||
if ansible_facts != host.ansible_facts:
|
||||
host.ansible_facts = ansible_facts
|
||||
host.ansible_facts_modified = now()
|
||||
hosts_to_update.append(host)
|
||||
logger.info(
|
||||
f'New fact for inventory {smart_str(host.inventory.name)} host {smart_str(host.name)}',
|
||||
extra=dict(
|
||||
inventory_id=host.inventory.id,
|
||||
host_name=host.name,
|
||||
ansible_facts=host.ansible_facts,
|
||||
ansible_facts_modified=host.ansible_facts_modified.isoformat(),
|
||||
job_id=job_id,
|
||||
),
|
||||
)
|
||||
log_data['updated_ct'] += 1
|
||||
else:
|
||||
log_data['unmodified_ct'] += 1
|
||||
else:
|
||||
log_data['unmodified_ct'] += 1
|
||||
else:
|
||||
# File is missing. Only interpret this as "ansible cleared facts" if
|
||||
# start_fact_cache actually wrote a file for this host (i.e. the host
|
||||
# had valid, non-expired facts before the job ran). If no file was
|
||||
# ever written, the missing file is expected and not a clear signal.
|
||||
if not hosts_cached_map.get(host.name):
|
||||
log_data['unmodified_ct'] += 1
|
||||
filepath = os.path.join(fact_cache_dir, filename)
|
||||
if os.path.islink(filepath):
|
||||
logger.error(f'Invalid path for facts file: {filepath}')
|
||||
continue
|
||||
if not os.path.isfile(filepath):
|
||||
continue
|
||||
|
||||
# if the file goes missing, ansible removed it (likely via clear_facts)
|
||||
# if the file goes missing, but the host has not started facts, then we should not clear the facts
|
||||
seen_in_dir.add(filename)
|
||||
try:
|
||||
modified = os.path.getmtime(filepath)
|
||||
except OSError as e:
|
||||
logger.warning(f'Could not stat facts file {filepath}: {e}')
|
||||
continue
|
||||
if modified >= facts_write_time:
|
||||
hosts_with_updates.add(filename)
|
||||
else:
|
||||
log_data['unmodified_ct'] += 1
|
||||
|
||||
# Check for files we wrote pre-job that are now missing (Ansible cleared facts)
|
||||
for hostname, was_written in hosts_cached_map.items():
|
||||
if hostname in seen_in_dir:
|
||||
continue # already handled above
|
||||
if was_written:
|
||||
hosts_to_clear.append(hostname)
|
||||
else:
|
||||
log_data['unmodified_ct'] += 1
|
||||
|
||||
# Phase 2: Stream updated facts to database in batches
|
||||
if hosts_with_updates:
|
||||
hosts_to_save = []
|
||||
total_rows_updated = 0
|
||||
for host in host_qs.filter(name__in=list(hosts_with_updates)).select_related('inventory').iterator():
|
||||
filepath = os.path.join(fact_cache_dir, host.name)
|
||||
try:
|
||||
with codecs.open(filepath, 'r', encoding='utf-8') as f:
|
||||
new_facts = json.load(f)
|
||||
except (ValueError, OSError):
|
||||
continue
|
||||
|
||||
if new_facts != host.ansible_facts:
|
||||
host.ansible_facts = new_facts
|
||||
host.ansible_facts_modified = now()
|
||||
hosts_to_save.append(host)
|
||||
logger.info(
|
||||
f'New fact for inventory {smart_str(host.inventory.name)} host {smart_str(host.name)}',
|
||||
extra=dict(
|
||||
inventory_id=host.inventory.id,
|
||||
host_name=host.name,
|
||||
ansible_facts=host.ansible_facts,
|
||||
ansible_facts_modified=host.ansible_facts_modified.isoformat(),
|
||||
job_id=job_id,
|
||||
),
|
||||
)
|
||||
log_data['updated_ct'] += 1
|
||||
else:
|
||||
log_data['unmodified_ct'] += 1
|
||||
|
||||
if len(hosts_to_save) >= 100:
|
||||
total_rows_updated += bulk_update_sorted_by_id(Host, hosts_to_save, fields=['ansible_facts', 'ansible_facts_modified'])
|
||||
hosts_to_save = []
|
||||
|
||||
if hosts_to_save:
|
||||
total_rows_updated += bulk_update_sorted_by_id(Host, hosts_to_save, fields=['ansible_facts', 'ansible_facts_modified'])
|
||||
|
||||
# Mismatch means a concurrent process changed or deleted hosts between our read and bulk update
|
||||
if total_rows_updated != log_data['updated_ct']:
|
||||
logger.warning(
|
||||
f'Fact update for inventory {inventory_id} job {job_id}: expected to update {log_data["updated_ct"]} hosts but {total_rows_updated} rows were changed'
|
||||
)
|
||||
|
||||
# Phase 3: Clear facts for hosts whose files were removed by Ansible
|
||||
if hosts_to_clear:
|
||||
hosts = list(host_qs.filter(name__in=hosts_to_clear).select_related('inventory'))
|
||||
clear_hosts = []
|
||||
for host in hosts:
|
||||
if job_created and host.ansible_facts_modified and host.ansible_facts_modified > job_created:
|
||||
logger.warning(
|
||||
f'Skipping fact clear for host {smart_str(host.name)} in job {job_id} '
|
||||
@@ -169,13 +204,13 @@ def finish_fact_cache(host_qs, artifacts_dir, job_id=None, inventory_id=None, jo
|
||||
else:
|
||||
host.ansible_facts = {}
|
||||
host.ansible_facts_modified = now()
|
||||
hosts_to_update.append(host)
|
||||
clear_hosts.append(host)
|
||||
logger.info(f'Facts cleared for inventory {smart_str(host.inventory.name)} host {smart_str(host.name)}')
|
||||
log_data['cleared_ct'] += 1
|
||||
|
||||
if len(hosts_to_update) >= 100:
|
||||
bulk_update_sorted_by_id(Host, hosts_to_update, fields=['ansible_facts', 'ansible_facts_modified'])
|
||||
hosts_to_update = []
|
||||
if clear_hosts:
|
||||
rows = bulk_update_sorted_by_id(Host, clear_hosts, fields=['ansible_facts', 'ansible_facts_modified'])
|
||||
if rows != len(clear_hosts):
|
||||
logger.warning(f'Fact clear for inventory {inventory_id} job {job_id}: expected to clear {len(clear_hosts)} hosts but {rows} rows were changed')
|
||||
|
||||
bulk_update_sorted_by_id(Host, hosts_to_update, fields=['ansible_facts', 'ansible_facts_modified'])
|
||||
logger.debug(f'Updated {log_data["updated_ct"]} host facts for inventory {inventory_id} in job {job_id}')
|
||||
|
||||
@@ -94,7 +94,7 @@ from flags.state import flag_enabled
|
||||
|
||||
# Workload Identity
|
||||
from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope
|
||||
from ansible_base.resource_registry.workload_identity_client import get_workload_identity_client
|
||||
from awx.main.utils.workload_identity import retrieve_workload_identity_jwt_with_claims
|
||||
|
||||
logger = logging.getLogger('awx.main.tasks.jobs')
|
||||
|
||||
@@ -168,14 +168,12 @@ def retrieve_workload_identity_jwt(
|
||||
Raises:
|
||||
RuntimeError: if the workload identity client is not configured.
|
||||
"""
|
||||
client = get_workload_identity_client()
|
||||
if client is None:
|
||||
raise RuntimeError("Workload identity client is not configured")
|
||||
claims = populate_claims_for_workload(unified_job)
|
||||
kwargs = {"claims": claims, "scope": scope, "audience": audience}
|
||||
if workload_ttl_seconds:
|
||||
kwargs["workload_ttl_seconds"] = workload_ttl_seconds
|
||||
return client.request_workload_jwt(**kwargs).jwt
|
||||
return retrieve_workload_identity_jwt_with_claims(
|
||||
populate_claims_for_workload(unified_job),
|
||||
audience,
|
||||
scope,
|
||||
workload_ttl_seconds,
|
||||
)
|
||||
|
||||
|
||||
def with_path_cleanup(f):
|
||||
@@ -230,16 +228,19 @@ class BaseTask(object):
|
||||
# Convert to list to prevent re-evaluation of QuerySet
|
||||
return list(credentials_list)
|
||||
|
||||
def populate_workload_identity_tokens(self):
|
||||
def populate_workload_identity_tokens(self, additional_credentials=None):
|
||||
"""
|
||||
Populate credentials with workload identity tokens.
|
||||
|
||||
Sets the context on Credential objects that have input sources
|
||||
using compatible external credential types.
|
||||
"""
|
||||
credentials = list(self._credentials)
|
||||
if additional_credentials:
|
||||
credentials.extend(additional_credentials)
|
||||
credential_input_sources = (
|
||||
(credential.context, src)
|
||||
for credential in self._credentials
|
||||
for credential in credentials
|
||||
for src in credential.input_sources.all()
|
||||
if any(
|
||||
field.get('id') == 'workload_identity_token' and field.get('internal')
|
||||
@@ -253,7 +254,7 @@ class BaseTask(object):
|
||||
try:
|
||||
jwt = retrieve_workload_identity_jwt(
|
||||
self.instance,
|
||||
audience=input_src.source_credential.get_input('jwt_aud'),
|
||||
audience=input_src.source_credential.get_input('url'),
|
||||
scope=AutomationControllerJobScope.name,
|
||||
workload_ttl_seconds=workload_ttl,
|
||||
)
|
||||
@@ -1137,12 +1138,11 @@ class RunJob(SourceControlMixin, BaseTask):
|
||||
('ANSIBLE_COLLECTIONS_PATH', 'collections_path', 'requirements_collections', '~/.ansible/collections:/usr/share/ansible/collections'),
|
||||
]
|
||||
|
||||
if flag_enabled("FEATURE_INDIRECT_NODE_COUNTING_ENABLED"):
|
||||
path_vars.append(
|
||||
('ANSIBLE_CALLBACK_PLUGINS', 'callback_plugins', 'plugins_path', '~/.ansible/plugins:/plugins/callback:/usr/share/ansible/plugins/callback'),
|
||||
)
|
||||
path_vars.append(
|
||||
('ANSIBLE_CALLBACK_PLUGINS', 'callback_plugins', 'plugins_path', '~/.ansible/plugins:/plugins/callback:/usr/share/ansible/plugins/callback'),
|
||||
)
|
||||
|
||||
config_values = read_ansible_config(os.path.join(private_data_dir, 'project'), list(map(lambda x: x[1], path_vars)))
|
||||
config_values = read_ansible_config(os.path.join(private_data_dir, 'project'), list(map(lambda x: x[1], path_vars)) + ['callbacks_enabled'])
|
||||
|
||||
for env_key, config_setting, folder, default in path_vars:
|
||||
paths = default.split(':')
|
||||
@@ -1157,11 +1157,12 @@ class RunJob(SourceControlMixin, BaseTask):
|
||||
paths = [os.path.join(CONTAINER_ROOT, folder)] + paths
|
||||
env[env_key] = os.pathsep.join(paths)
|
||||
|
||||
if flag_enabled("FEATURE_INDIRECT_NODE_COUNTING_ENABLED"):
|
||||
env['ANSIBLE_CALLBACKS_ENABLED'] = 'indirect_instance_count'
|
||||
if 'callbacks_enabled' in config_values:
|
||||
env['ANSIBLE_CALLBACKS_ENABLED'] += ':' + config_values['callbacks_enabled']
|
||||
env['ANSIBLE_CALLBACKS_ENABLED'] = 'indirect_instance_count'
|
||||
if 'callbacks_enabled' in config_values:
|
||||
env['ANSIBLE_CALLBACKS_ENABLED'] += ',' + config_values['callbacks_enabled']
|
||||
|
||||
if flag_enabled("FEATURE_INDIRECT_NODE_COUNTING_ENABLED"):
|
||||
env['AWX_COLLECT_HOST_QUERIES'] = '1'
|
||||
# Add vendor collections path for external query file discovery
|
||||
vendor_collections_path = os.path.join(CONTAINER_ROOT, 'vendor_collections')
|
||||
env['ANSIBLE_COLLECTIONS_PATH'] = f"{vendor_collections_path}:{env['ANSIBLE_COLLECTIONS_PATH']}"
|
||||
@@ -1330,6 +1331,7 @@ class RunJob(SourceControlMixin, BaseTask):
|
||||
hosts_qs = job.get_source_hosts_for_constructed_inventory()
|
||||
else:
|
||||
hosts_qs = job.inventory.hosts
|
||||
hosts_qs = hosts_qs.only(*HOST_FACTS_FIELDS)
|
||||
finish_fact_cache(
|
||||
hosts_qs,
|
||||
artifacts_dir=os.path.join(private_data_dir, 'artifacts', str(job.id)),
|
||||
@@ -1610,16 +1612,14 @@ class RunProjectUpdate(BaseTask):
|
||||
shutil.copytree(cache_subpath, dest_subpath, symlinks=True)
|
||||
logger.debug('{0} {1} prepared {2} from cache'.format(type(project).__name__, project.pk, dest_subpath))
|
||||
|
||||
if flag_enabled("FEATURE_INDIRECT_NODE_COUNTING_ENABLED"):
|
||||
# copy the special callback (not stdout type) plugin to get list of collections
|
||||
pdd_plugins_path = os.path.join(job_private_data_dir, 'plugins_path')
|
||||
if not os.path.exists(pdd_plugins_path):
|
||||
os.mkdir(pdd_plugins_path)
|
||||
from awx.playbooks import library
|
||||
pdd_plugins_path = os.path.join(job_private_data_dir, 'plugins_path')
|
||||
if not os.path.exists(pdd_plugins_path):
|
||||
os.mkdir(pdd_plugins_path)
|
||||
from awx.playbooks import library
|
||||
|
||||
plugin_file_source = os.path.join(library.__path__._path[0], 'indirect_instance_count.py')
|
||||
plugin_file_dest = os.path.join(pdd_plugins_path, 'indirect_instance_count.py')
|
||||
shutil.copyfile(plugin_file_source, plugin_file_dest)
|
||||
plugin_file_source = os.path.join(library.__path__[0], 'indirect_instance_count.py')
|
||||
plugin_file_dest = os.path.join(pdd_plugins_path, 'indirect_instance_count.py')
|
||||
shutil.copyfile(plugin_file_source, plugin_file_dest)
|
||||
|
||||
def post_run_hook(self, instance, status):
|
||||
super(RunProjectUpdate, self).post_run_hook(instance, status)
|
||||
@@ -1865,6 +1865,24 @@ class RunInventoryUpdate(SourceControlMixin, BaseTask):
|
||||
# All credentials not used by inventory source injector
|
||||
return inventory_update.get_extra_credentials()
|
||||
|
||||
def populate_workload_identity_tokens(self, additional_credentials=None):
|
||||
"""Also generate OIDC tokens for the cloud credential.
|
||||
|
||||
The cloud credential is not in _credentials (it is handled by the
|
||||
inventory source injector), but it may still need a workload identity
|
||||
token generated for it.
|
||||
"""
|
||||
cloud_cred = self.instance.get_cloud_credential()
|
||||
creds = list(additional_credentials or [])
|
||||
if cloud_cred:
|
||||
creds.append(cloud_cred)
|
||||
super().populate_workload_identity_tokens(additional_credentials=creds or None)
|
||||
# Override get_cloud_credential on this instance so the injector
|
||||
# uses the credential with OIDC context instead of doing a fresh
|
||||
# DB fetch that would lose it.
|
||||
if cloud_cred and cloud_cred.context:
|
||||
self.instance.get_cloud_credential = lambda: cloud_cred
|
||||
|
||||
def build_project_dir(self, inventory_update, private_data_dir):
|
||||
source_project = None
|
||||
if inventory_update.inventory_source:
|
||||
|
||||
@@ -19,6 +19,7 @@ from dispatcherd.publish import task
|
||||
# Runner
|
||||
import ansible_runner.cleanup
|
||||
import psycopg
|
||||
from ansible_base.lib.cache.tasks import clear_cache as dab_clear_cache
|
||||
from ansible_base.lib.utils.db import advisory_lock
|
||||
|
||||
# django-ansible-base
|
||||
@@ -68,10 +69,12 @@ from awx.main.models import (
|
||||
UnifiedJob,
|
||||
convert_jsonfields,
|
||||
)
|
||||
from awx.main.models.credential import CredentialType
|
||||
from awx.main.tasks.helpers import is_run_threshold_reached
|
||||
from awx.main.tasks.host_indirect import save_indirect_host_entries
|
||||
from awx.main.tasks.receptor import administrative_workunit_reaper, get_receptor_ctl, worker_cleanup, worker_info, write_receptor_config
|
||||
from awx.main.utils.common import ignore_inventory_computed_fields, ignore_inventory_group_removal
|
||||
from awx.main.utils.migration import is_database_synchronized
|
||||
from awx.main.utils.reload import stop_local_services
|
||||
|
||||
logger = logging.getLogger('awx.main.tasks.system')
|
||||
@@ -83,6 +86,16 @@ Try upgrading OpenSSH or providing your private key in an different format. \
|
||||
'''
|
||||
|
||||
|
||||
def _sync_credential_types_to_db():
|
||||
"""Ensure CredentialType DB rows match the installed plugins.
|
||||
|
||||
The in-memory registry is populated lazily on first access via LazyLoadDict.
|
||||
This function only handles the DB sync step.
|
||||
"""
|
||||
if is_database_synchronized():
|
||||
CredentialType.setup_tower_managed_defaults()
|
||||
|
||||
|
||||
def _run_dispatch_startup_common():
|
||||
"""
|
||||
Execute the common startup initialization steps.
|
||||
@@ -98,6 +111,11 @@ def _run_dispatch_startup_common():
|
||||
except Exception:
|
||||
logger.exception("Failed to write receptor config, skipping.")
|
||||
|
||||
try:
|
||||
_sync_credential_types_to_db()
|
||||
except Exception:
|
||||
logger.exception("Failed to sync credential types to DB, skipping.")
|
||||
|
||||
try:
|
||||
convert_jsonfields()
|
||||
except Exception:
|
||||
@@ -240,12 +258,17 @@ def apply_cluster_membership_policies():
|
||||
# Process policy instance list first, these will represent manually managed memberships
|
||||
instance_hostnames_map = {inst.hostname: inst for inst in all_instances}
|
||||
for ig in all_groups:
|
||||
# we don't want to allow execution nodes in the control plane
|
||||
exclude_type = 'execution' if ig.name == settings.DEFAULT_CONTROL_PLANE_QUEUE_NAME else 'control'
|
||||
group_actual = Group(obj=ig, instances=[], prior_instances=[instance.pk for instance in ig.instances.all()]) # obtained in prefetch
|
||||
for hostname in ig.policy_instance_list:
|
||||
if hostname not in instance_hostnames_map:
|
||||
logger.info("Unknown instance {} in {} policy list".format(hostname, ig.name))
|
||||
continue
|
||||
inst = instance_hostnames_map[hostname]
|
||||
if inst.node_type == exclude_type:
|
||||
logger.info("Instance {} is excluded in {} policy list".format(hostname, ig.name))
|
||||
continue
|
||||
group_actual.instances.append(inst.id)
|
||||
# NOTE: arguable behavior: policy-list-group is not added to
|
||||
# instance's group count for consideration in minimum-policy rules
|
||||
@@ -326,24 +349,22 @@ def apply_cluster_membership_policies():
|
||||
logger.debug('Cluster policy computation finished in {} seconds'.format(time.time() - started_compute))
|
||||
|
||||
|
||||
@task(queue='tower_settings_change', timeout=600)
|
||||
def clear_setting_cache(setting_keys):
|
||||
# log that cache is being cleared
|
||||
logger.info(f"clear_setting_cache of keys {setting_keys}")
|
||||
orig_len = len(setting_keys)
|
||||
for i in range(orig_len):
|
||||
for dependent_key in settings_registry.get_dependent_settings(setting_keys[i]):
|
||||
setting_keys.append(dependent_key)
|
||||
cache_keys = set(setting_keys)
|
||||
logger.debug('cache delete_many(%r)', cache_keys)
|
||||
cache.delete_many(cache_keys)
|
||||
def _resolve_setting_dependents(key):
|
||||
return settings_registry.get_dependent_settings(key)
|
||||
|
||||
if 'LOG_AGGREGATOR_LEVEL' in setting_keys:
|
||||
|
||||
def _post_setting_invalidation(invalidated_keys):
|
||||
if 'LOG_AGGREGATOR_LEVEL' in invalidated_keys:
|
||||
ctl = get_control_from_settings()
|
||||
ctl.queuename = get_task_queuename()
|
||||
ctl.control('set_log_level', data={'level': settings.LOG_AGGREGATOR_LEVEL})
|
||||
|
||||
|
||||
@task(queue='tower_settings_change', timeout=600)
|
||||
def clear_setting_cache(setting_keys):
|
||||
dab_clear_cache(setting_keys, _resolve_setting_dependents, _post_setting_invalidation)
|
||||
|
||||
|
||||
@task(queue='tower_broadcast_all', timeout=600)
|
||||
def delete_project_files(project_path):
|
||||
# TODO: possibly implement some retry logic
|
||||
|
||||
11
awx/main/tests/data/projects/debug/set_stats.yml
Normal file
11
awx/main/tests/data/projects/debug/set_stats.yml
Normal file
@@ -0,0 +1,11 @@
|
||||
---
|
||||
- hosts: all
|
||||
gather_facts: false
|
||||
connection: local
|
||||
tasks:
|
||||
- name: Set artifacts via set_stats
|
||||
ansible.builtin.set_stats:
|
||||
data: "{{ stats_data }}"
|
||||
per_host: false
|
||||
aggregate: false
|
||||
when: stats_data is defined
|
||||
@@ -74,9 +74,9 @@ def temp_analytic_tar():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_analytic_post():
|
||||
# Patch the Session.post method to return a mock response with status_code 200
|
||||
with mock.patch('awx.main.analytics.core.requests.Session.post', return_value=mock.Mock(status_code=200)) as mock_post:
|
||||
yield mock_post
|
||||
# Patch get_or_generate_candlepin_certificate to skip mTLS path
|
||||
with mock.patch('awx.main.analytics.core.get_or_generate_candlepin_certificate', return_value=(None, None)):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -141,15 +141,22 @@ def mock_analytic_post():
|
||||
)
|
||||
@pytest.mark.django_db
|
||||
def test_ship_credential(setting_map, expected_result, expected_auth, temp_analytic_tar, mock_analytic_post):
|
||||
with override_settings(**setting_map):
|
||||
result = ship(temp_analytic_tar)
|
||||
with override_settings(**setting_map, AUTOMATION_ANALYTICS_URL='https://example.com/api'):
|
||||
with mock.patch('awx.main.analytics.core.OIDCClient') as mock_oidc:
|
||||
mock_oidc_instance = mock.Mock()
|
||||
mock_oidc_instance.make_request.return_value = mock.Mock(status_code=200)
|
||||
mock_oidc.return_value = mock_oidc_instance
|
||||
|
||||
assert result == expected_result
|
||||
if expected_auth:
|
||||
mock_analytic_post.assert_called_once()
|
||||
assert mock_analytic_post.call_args[1]['auth'] == expected_auth
|
||||
else:
|
||||
mock_analytic_post.assert_not_called()
|
||||
result = ship(temp_analytic_tar)
|
||||
|
||||
assert result == expected_result
|
||||
if expected_auth:
|
||||
# Verify OIDC client was instantiated with correct credentials
|
||||
mock_oidc.assert_called_once_with(expected_auth[0], expected_auth[1])
|
||||
mock_oidc_instance.make_request.assert_called_once()
|
||||
else:
|
||||
# When credentials are missing, OIDCClient should not be called
|
||||
mock_oidc.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
|
||||
84
awx/main/tests/functional/api/test_config_endpoint.py
Normal file
84
awx/main/tests/functional/api/test_config_endpoint.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import pytest
|
||||
from awx.api.versioning import reverse
|
||||
from rest_framework import status
|
||||
|
||||
from awx.main.models.jobs import JobTemplate
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestConfigEndpointFields:
|
||||
def test_base_fields_all_users(self, get, rando):
|
||||
url = reverse('api:api_v2_config_view')
|
||||
response = get(url, rando, expect=200)
|
||||
|
||||
assert 'time_zone' in response.data
|
||||
assert 'license_info' in response.data
|
||||
assert 'version' in response.data
|
||||
assert 'eula' in response.data
|
||||
assert 'analytics_status' in response.data
|
||||
assert 'analytics_collectors' in response.data
|
||||
assert 'become_methods' in response.data
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"role_type",
|
||||
[
|
||||
"superuser",
|
||||
"system_auditor",
|
||||
"org_admin",
|
||||
"org_auditor",
|
||||
"org_project_admin",
|
||||
],
|
||||
)
|
||||
def test_privileged_users_conditional_fields(self, get, user, organization, admin, role_type):
|
||||
url = reverse('api:api_v2_config_view')
|
||||
|
||||
if role_type == "superuser":
|
||||
test_user = admin
|
||||
elif role_type == "system_auditor":
|
||||
test_user = user('system-auditor', is_superuser=False)
|
||||
test_user.is_system_auditor = True
|
||||
test_user.save()
|
||||
elif role_type == "org_admin":
|
||||
test_user = user('org-admin', is_superuser=False)
|
||||
organization.admin_role.members.add(test_user)
|
||||
elif role_type == "org_auditor":
|
||||
test_user = user('org-auditor', is_superuser=False)
|
||||
organization.auditor_role.members.add(test_user)
|
||||
elif role_type == "org_project_admin":
|
||||
test_user = user('org-project-admin', is_superuser=False)
|
||||
organization.project_admin_role.members.add(test_user)
|
||||
|
||||
response = get(url, test_user, expect=200)
|
||||
|
||||
assert 'project_base_dir' in response.data
|
||||
assert 'project_local_paths' in response.data
|
||||
assert 'custom_virtualenvs' in response.data
|
||||
|
||||
def test_job_template_admin_gets_venvs_only(self, get, user, organization, project, inventory):
|
||||
"""Test that JobTemplate admin without org access gets only custom_virtualenvs"""
|
||||
jt_admin = user('jt-admin', is_superuser=False)
|
||||
|
||||
jt = JobTemplate.objects.create(name='test-jt', organization=organization, project=project, inventory=inventory)
|
||||
jt.admin_role.members.add(jt_admin)
|
||||
|
||||
url = reverse('api:api_v2_config_view')
|
||||
response = get(url, jt_admin, expect=200)
|
||||
|
||||
assert 'custom_virtualenvs' in response.data
|
||||
assert 'project_base_dir' not in response.data
|
||||
assert 'project_local_paths' not in response.data
|
||||
|
||||
def test_normal_user_no_conditional_fields(self, get, rando):
|
||||
url = reverse('api:api_v2_config_view')
|
||||
response = get(url, rando, expect=200)
|
||||
|
||||
assert 'project_base_dir' not in response.data
|
||||
assert 'project_local_paths' not in response.data
|
||||
assert 'custom_virtualenvs' not in response.data
|
||||
|
||||
def test_unauthenticated_denied(self, get):
|
||||
"""Test that unauthenticated requests are denied"""
|
||||
url = reverse('api:api_v2_config_view')
|
||||
response = get(url, None, expect=401)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -200,6 +200,7 @@ def test_grant_org_credential_to_org_user_through_user_roles(post, credential, o
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_grant_org_credential_to_non_org_user_through_role_users(post, credential, organization, org_admin, alice):
|
||||
# NOTE: this endpoint is going away soon
|
||||
credential.organization = organization
|
||||
credential.save()
|
||||
response = post(reverse('api:role_users_list', kwargs={'pk': credential.use_role.id}), {'id': alice.id}, org_admin)
|
||||
@@ -208,6 +209,7 @@ def test_grant_org_credential_to_non_org_user_through_role_users(post, credentia
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_grant_org_credential_to_non_org_user_through_user_roles(post, credential, organization, org_admin, alice):
|
||||
# NOTE: this endpoint is going away soon
|
||||
credential.organization = organization
|
||||
credential.save()
|
||||
response = post(reverse('api:user_roles_list', kwargs={'pk': alice.id}), {'id': credential.use_role.id}, org_admin)
|
||||
@@ -216,18 +218,18 @@ def test_grant_org_credential_to_non_org_user_through_user_roles(post, credentia
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_grant_private_credential_to_user_through_role_users(post, credential, alice, bob):
|
||||
# normal users can't do this
|
||||
# NOTE: this endpoint is going away soon
|
||||
credential.admin_role.members.add(alice)
|
||||
response = post(reverse('api:role_users_list', kwargs={'pk': credential.use_role.id}), {'id': bob.id}, alice)
|
||||
assert response.status_code == 400
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_grant_private_credential_to_org_user_through_role_users(post, credential, org_admin, org_member):
|
||||
# org admins can't either
|
||||
# NOTE: this endpoint is going away soon
|
||||
credential.admin_role.members.add(org_admin)
|
||||
response = post(reverse('api:role_users_list', kwargs={'pk': credential.use_role.id}), {'id': org_member.id}, org_admin)
|
||||
assert response.status_code == 400
|
||||
assert response.status_code == 204
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@@ -239,18 +241,18 @@ def test_sa_grant_private_credential_to_user_through_role_users(post, credential
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_grant_private_credential_to_user_through_user_roles(post, credential, alice, bob):
|
||||
# normal users can't do this
|
||||
# NOTE: this endpoint is going away soon
|
||||
credential.admin_role.members.add(alice)
|
||||
response = post(reverse('api:user_roles_list', kwargs={'pk': bob.id}), {'id': credential.use_role.id}, alice)
|
||||
assert response.status_code == 400
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_grant_private_credential_to_org_user_through_user_roles(post, credential, org_admin, org_member):
|
||||
# org admins can't either
|
||||
# NOTE: this endpoint is going away soon
|
||||
credential.admin_role.members.add(org_admin)
|
||||
response = post(reverse('api:user_roles_list', kwargs={'pk': org_member.id}), {'id': credential.use_role.id}, org_admin)
|
||||
assert response.status_code == 400
|
||||
assert response.status_code == 204
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@@ -282,14 +284,14 @@ def test_grant_org_credential_to_team_through_team_roles(post, credential, organ
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_sa_grant_private_credential_to_team_through_role_teams(post, credential, admin, team):
|
||||
# not even a system admin can grant a private cred to a team though
|
||||
# NOTE: this endpoint is going away soon
|
||||
response = post(reverse('api:role_teams_list', kwargs={'pk': credential.use_role.id}), {'id': team.id}, admin)
|
||||
assert response.status_code == 400
|
||||
assert response.status_code == 204
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_grant_credential_to_team_different_organization_through_role_teams(post, get, credential, organizations, admin, org_admin, team, team_member):
|
||||
# # Test that credential from different org can be assigned to team by a superuser through role_teams_list endpoint
|
||||
# NOTE: this endpoint is going away soon
|
||||
orgs = organizations(2)
|
||||
credential.organization = orgs[0]
|
||||
credential.save()
|
||||
@@ -299,10 +301,7 @@ def test_grant_credential_to_team_different_organization_through_role_teams(post
|
||||
# Non-superuser (org_admin) trying cross-org assignment should be denied
|
||||
response = post(reverse('api:role_teams_list', kwargs={'pk': credential.use_role.id}), {'id': team.id}, org_admin)
|
||||
assert response.status_code == 400
|
||||
assert (
|
||||
"You cannot grant a team access to a credential in a different organization. Only superusers can grant cross-organization credential access to teams"
|
||||
in response.data['msg']
|
||||
)
|
||||
assert "You cannot grant credential access to a Team not in the credentials' organization" in str(response.data['detail'])
|
||||
|
||||
# Superuser (admin) can do cross-org assignment
|
||||
response = post(reverse('api:role_teams_list', kwargs={'pk': credential.use_role.id}), {'id': team.id}, admin)
|
||||
@@ -316,20 +315,17 @@ def test_grant_credential_to_team_different_organization_through_role_teams(post
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_grant_credential_to_team_different_organization(post, get, credential, organizations, admin, org_admin, team, team_member):
|
||||
# Test that credential from different org can be assigned to team by a superuser
|
||||
# NOTE: this endpoint is going away soon
|
||||
orgs = organizations(2)
|
||||
credential.organization = orgs[0]
|
||||
credential.save()
|
||||
team.organization = orgs[1]
|
||||
team.save()
|
||||
|
||||
# Non-superuser (org_admin, ...) trying cross-org assignment should be denied
|
||||
# Non-superuser (org_admin) trying cross-org assignment should be denied
|
||||
response = post(reverse('api:team_roles_list', kwargs={'pk': team.id}), {'id': credential.use_role.id}, org_admin)
|
||||
assert response.status_code == 400
|
||||
assert (
|
||||
"You cannot grant a team access to a credential in a different organization. Only superusers can grant cross-organization credential access to teams"
|
||||
in response.data['msg']
|
||||
)
|
||||
assert "You cannot grant credential access to a Team not in the credentials' organization" in str(response.data['detail'])
|
||||
|
||||
# Superuser (system admin) can do cross-org assignment
|
||||
response = post(reverse('api:team_roles_list', kwargs={'pk': team.id}), {'id': credential.use_role.id}, admin)
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
|
||||
import pytest
|
||||
|
||||
from ansible_base.lib.testing.util import feature_flag_enabled
|
||||
from awx.main.models.credential import CredentialType, Credential
|
||||
from awx.api.versioning import reverse
|
||||
|
||||
@@ -159,7 +160,8 @@ def test_create_as_admin(get, post, admin):
|
||||
response = get(reverse('api:credential_type_list'), admin)
|
||||
assert response.data['count'] == 1
|
||||
assert response.data['results'][0]['name'] == 'Custom Credential Type'
|
||||
assert response.data['results'][0]['inputs'] == {}
|
||||
# Serializer normalizes empty inputs to {'fields': []}
|
||||
assert response.data['results'][0]['inputs'] == {'fields': []}
|
||||
assert response.data['results'][0]['injectors'] == {}
|
||||
assert response.data['results'][0]['managed'] is False
|
||||
|
||||
@@ -474,3 +476,98 @@ def test_credential_type_rbac_external_test(post, alice, admin, credentialtype_e
|
||||
data = {'inputs': {}, 'metadata': {}}
|
||||
assert post(url, data, admin).status_code == 202
|
||||
assert post(url, data, alice).status_code == 403
|
||||
|
||||
|
||||
# --- Tests for internal field filtering with None/invalid inputs ---
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_credential_type_with_none_inputs(get, admin):
|
||||
"""Test that credential type with empty inputs dict works correctly."""
|
||||
# Create a credential type with empty dict
|
||||
ct = CredentialType.objects.create(
|
||||
kind='cloud',
|
||||
name='Test Type',
|
||||
managed=False,
|
||||
inputs={}, # Empty dict, not None (DB has NOT NULL constraint)
|
||||
)
|
||||
|
||||
url = reverse('api:credential_type_detail', kwargs={'pk': ct.pk})
|
||||
response = get(url, admin)
|
||||
assert response.status_code == 200
|
||||
# Should have normalized inputs to empty dict
|
||||
assert 'inputs' in response.data
|
||||
assert isinstance(response.data['inputs'], dict)
|
||||
assert response.data['inputs']['fields'] == []
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_credential_type_with_invalid_inputs_type(get, admin):
|
||||
"""Test that credential type with non-dict inputs doesn't cause errors."""
|
||||
# Create a credential type with invalid inputs type
|
||||
ct = CredentialType.objects.create(kind='cloud', name='Test Type', managed=False, inputs={'fields': 'not-a-list'})
|
||||
|
||||
url = reverse('api:credential_type_detail', kwargs={'pk': ct.pk})
|
||||
response = get(url, admin)
|
||||
assert response.status_code == 200
|
||||
# Should gracefully handle invalid fields type
|
||||
assert 'inputs' in response.data
|
||||
assert response.data['inputs']['fields'] == []
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_credential_type_filters_internal_fields(get, admin):
|
||||
"""Test that internal fields are filtered from API responses."""
|
||||
ct = CredentialType.objects.create(
|
||||
kind='cloud',
|
||||
name='Test OIDC Type',
|
||||
managed=False,
|
||||
inputs={
|
||||
'fields': [
|
||||
{'id': 'url', 'label': 'URL', 'type': 'string'},
|
||||
{'id': 'token', 'label': 'Token', 'type': 'string', 'secret': True, 'internal': True},
|
||||
{'id': 'public_field', 'label': 'Public', 'type': 'string'},
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
url = reverse('api:credential_type_detail', kwargs={'pk': ct.pk})
|
||||
with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
|
||||
response = get(url, admin)
|
||||
assert response.status_code == 200
|
||||
|
||||
field_ids = [f['id'] for f in response.data['inputs']['fields']]
|
||||
# Internal field should be filtered out
|
||||
assert 'token' not in field_ids
|
||||
assert 'url' in field_ids
|
||||
assert 'public_field' in field_ids
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_credential_type_list_filters_internal_fields(get, admin):
|
||||
"""Test that internal fields are filtered in list view."""
|
||||
CredentialType.objects.create(
|
||||
kind='cloud',
|
||||
name='Test OIDC Type',
|
||||
managed=False,
|
||||
inputs={
|
||||
'fields': [
|
||||
{'id': 'url', 'label': 'URL', 'type': 'string'},
|
||||
{'id': 'workload_identity_token', 'label': 'Token', 'type': 'string', 'secret': True, 'internal': True},
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
url = reverse('api:credential_type_list')
|
||||
with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
|
||||
response = get(url, admin)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Find our credential type in the results
|
||||
test_ct = next((ct for ct in response.data['results'] if ct['name'] == 'Test OIDC Type'), None)
|
||||
assert test_ct is not None
|
||||
|
||||
field_ids = [f['id'] for f in test_ct['inputs']['fields']]
|
||||
# Internal field should be filtered out
|
||||
assert 'workload_identity_token' not in field_ids
|
||||
assert 'url' in field_ids
|
||||
|
||||
34
awx/main/tests/functional/api/test_dashboard.py
Normal file
34
awx/main/tests/functional/api/test_dashboard.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import pytest
|
||||
|
||||
from awx.api.versioning import reverse
|
||||
from awx.main.models import Host, Inventory
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_dashboard_hosts_total_excludes_constructed(get, admin_user, organization):
|
||||
"""
|
||||
Constructed inventory hosts are not counted in the dashboard
|
||||
"""
|
||||
source_inv = Inventory.objects.create(name='source-inv', organization=organization)
|
||||
source_host = source_inv.hosts.create(name='host1')
|
||||
|
||||
constructed = Inventory.objects.create(name='constructed-inv', kind='constructed', organization=organization)
|
||||
Host.objects.create(name='host1', inventory=constructed, instance_id=str(source_host.pk))
|
||||
|
||||
response = get(reverse('api:dashboard_view'), user=admin_user, expect=200)
|
||||
assert response.data['hosts']['total'] == 1
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_host_list_still_returns_constructed(get, admin_user, organization):
|
||||
"""
|
||||
Constructed inventory hosts are still visible through the API
|
||||
"""
|
||||
source_inv = Inventory.objects.create(name='source-inv', organization=organization)
|
||||
source_host = source_inv.hosts.create(name='host1')
|
||||
|
||||
constructed = Inventory.objects.create(name='constructed-inv', kind='constructed', organization=organization)
|
||||
Host.objects.create(name='host1', inventory=constructed, instance_id=str(source_host.pk))
|
||||
|
||||
response = get(reverse('api:host_list'), user=admin_user, expect=200)
|
||||
assert response.data['count'] == 2
|
||||
@@ -1,5 +1,3 @@
|
||||
# TODO: As of writing this our only concern is ensuring that the fact feature is reflected in the Host endpoint.
|
||||
# Other host tests should live here to make this test suite more complete.
|
||||
import pytest
|
||||
import urllib.parse
|
||||
|
||||
@@ -20,6 +18,48 @@ def inventory_structure():
|
||||
Group.objects.create(name="g3", inventory=inv)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def host_filter_inventory():
|
||||
"""Inventory with hosts and groups matching the tower-qa test_host_filter structure.
|
||||
|
||||
Groups: groupA (contains groupAA as child), groupAA, groupB
|
||||
Hosts: hostA (in groupA), hostAA (in groupAA), hostB (in groupB), hostDup (in all 3 groups)
|
||||
"""
|
||||
org = Organization.objects.create(name="hf-org")
|
||||
inv = Inventory.objects.create(name="hf-inv", organization=org)
|
||||
|
||||
groupA = Group.objects.create(name="groupA", inventory=inv)
|
||||
groupAA = Group.objects.create(name="groupAA", inventory=inv)
|
||||
groupB = Group.objects.create(name="groupB", inventory=inv)
|
||||
|
||||
hostA = Host.objects.create(name="hostA", inventory=inv)
|
||||
hostAA = Host.objects.create(name="hostAA", inventory=inv)
|
||||
hostB = Host.objects.create(name="hostB", inventory=inv)
|
||||
hostDup = Host.objects.create(name="hostDup", inventory=inv)
|
||||
|
||||
groupA.hosts.add(hostA, hostDup)
|
||||
groupAA.hosts.add(hostAA, hostDup)
|
||||
groupB.hosts.add(hostB, hostDup)
|
||||
groupA.children.add(groupAA)
|
||||
|
||||
return {
|
||||
'org': org,
|
||||
'inv': inv,
|
||||
'hosts': {'hostA': hostA, 'hostAA': hostAA, 'hostB': hostB, 'hostDup': hostDup},
|
||||
'groups': {'groupA': groupA, 'groupAA': groupAA, 'groupB': groupB},
|
||||
}
|
||||
|
||||
|
||||
def get_host_names(response):
|
||||
return sorted(h['name'] for h in response.data['results'])
|
||||
|
||||
|
||||
def host_filter_get(get, user, host_filter):
|
||||
url = reverse('api:host_list')
|
||||
params = "?host_filter=%s" % urllib.parse.quote(host_filter, safe='')
|
||||
return get(url + params, user)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_q1(inventory_structure, get, user):
|
||||
def evaluate_query(query, expected_hosts):
|
||||
@@ -50,3 +90,184 @@ def test_q1(inventory_structure, get, user):
|
||||
# The following test verifies if the search in host_filter is case insensitive.
|
||||
query = 'search="HOST1"'
|
||||
evaluate_query(query, [hosts[0]])
|
||||
|
||||
|
||||
# --- Host filter query tests (migrated from tower-qa test_host_filter.py) ---
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.parametrize(
|
||||
"host_filter, expected",
|
||||
[
|
||||
("name=hostA", ["hostA"]),
|
||||
("name=not_found", []),
|
||||
("name=hostDup", ["hostDup"]),
|
||||
],
|
||||
)
|
||||
def test_basic_host_name_search(host_filter_inventory, get, admin_user, host_filter, expected):
|
||||
response = host_filter_get(get, admin_user, host_filter)
|
||||
assert response.status_code == 200
|
||||
assert get_host_names(response) == sorted(expected)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.parametrize(
|
||||
"host_filter, expected",
|
||||
[
|
||||
("name=hostA or name=hostB", ["hostA", "hostB"]),
|
||||
("name=hostA or name=not_found", ["hostA"]),
|
||||
("name=not_found or name=not_found", []),
|
||||
("name=hostA or name=hostA", ["hostA"]),
|
||||
("name=hostDup or name=hostDup", ["hostDup"]),
|
||||
("name=hostA or name=hostAA or name=not_found", ["hostA", "hostAA"]),
|
||||
],
|
||||
)
|
||||
def test_host_name_search_with_or(host_filter_inventory, get, admin_user, host_filter, expected):
|
||||
response = host_filter_get(get, admin_user, host_filter)
|
||||
assert response.status_code == 200
|
||||
assert get_host_names(response) == sorted(expected)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.parametrize(
|
||||
"host_filter, expected",
|
||||
[
|
||||
("name=hostA and name=hostB", []),
|
||||
("name=hostA and name=hostA", ["hostA"]),
|
||||
("name=not_found and name=not_found", []),
|
||||
("name=hostDup and name=hostDup", ["hostDup"]),
|
||||
("name=hostA and name=hostB and name=not_found", []),
|
||||
],
|
||||
)
|
||||
def test_host_name_search_with_and(host_filter_inventory, get, admin_user, host_filter, expected):
|
||||
response = host_filter_get(get, admin_user, host_filter)
|
||||
assert response.status_code == 200
|
||||
assert get_host_names(response) == sorted(expected)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.parametrize(
|
||||
"host_filter, expected",
|
||||
[
|
||||
("groups__name=groupA", ["hostA", "hostDup"]),
|
||||
("groups__name=groupAA", ["hostAA", "hostDup"]),
|
||||
("groups__name=not_found", []),
|
||||
],
|
||||
)
|
||||
def test_basic_group_search(host_filter_inventory, get, admin_user, host_filter, expected):
|
||||
response = host_filter_get(get, admin_user, host_filter)
|
||||
assert response.status_code == 200
|
||||
assert get_host_names(response) == sorted(expected)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.parametrize(
|
||||
"host_filter, expected",
|
||||
[
|
||||
("groups__name=groupA or groups__name=groupB", ["hostA", "hostB", "hostDup"]),
|
||||
("groups__name=groupA or groups__name=not_found", ["hostA", "hostDup"]),
|
||||
("groups__name=not_found or groups__name=not_found", []),
|
||||
("groups__name=groupA or groups__name=groupA", ["hostA", "hostDup"]),
|
||||
(
|
||||
"groups__name=groupA or groups__name=groupAA or groups__name=not_found",
|
||||
["hostA", "hostAA", "hostDup"],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_group_search_with_or(host_filter_inventory, get, admin_user, host_filter, expected):
|
||||
response = host_filter_get(get, admin_user, host_filter)
|
||||
assert response.status_code == 200
|
||||
assert get_host_names(response) == sorted(expected)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.parametrize(
|
||||
"host_filter, expected",
|
||||
[
|
||||
("groups__name=groupA and groups__name=groupB", ["hostDup"]),
|
||||
("groups__name=groupA and groups__name=groupA", ["hostA", "hostDup"]),
|
||||
("groups__name=not_found and groups__name=not_found", []),
|
||||
("groups__name=groupA and groups__name=groupB and groups__name=not_found", []),
|
||||
],
|
||||
)
|
||||
def test_group_search_with_and(host_filter_inventory, get, admin_user, host_filter, expected):
|
||||
response = host_filter_get(get, admin_user, host_filter)
|
||||
assert response.status_code == 200
|
||||
assert get_host_names(response) == sorted(expected)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.parametrize(
|
||||
"host_filter, expected",
|
||||
[
|
||||
("name=hostA or groups__name=groupB", ["hostA", "hostB", "hostDup"]),
|
||||
("name=hostA and groups__name=groupA", ["hostA"]),
|
||||
("name=hostA and groups__name=not_found", []),
|
||||
("name=not_found and groups__name=not_found", []),
|
||||
("name=hostDup and groups__name=groupA", ["hostDup"]),
|
||||
("name=hostDup and groups__name=groupB", ["hostDup"]),
|
||||
],
|
||||
)
|
||||
def test_basic_hybrid_search(host_filter_inventory, get, admin_user, host_filter, expected):
|
||||
response = host_filter_get(get, admin_user, host_filter)
|
||||
assert response.status_code == 200
|
||||
assert get_host_names(response) == sorted(expected)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_smart_search(get, admin_user):
|
||||
org = Organization.objects.create(name="search-org")
|
||||
inv = Inventory.objects.create(name="search-inv", organization=org)
|
||||
host = Host.objects.create(name="unique_search_target", description="findme_description", inventory=inv)
|
||||
|
||||
for search_term in ["unique_search_target", "findme_description"]:
|
||||
response = host_filter_get(get, admin_user, "search=%s" % search_term)
|
||||
assert response.status_code == 200
|
||||
names = get_host_names(response)
|
||||
assert host.name in names
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_password_field_filter_blocked(get, admin_user):
|
||||
url = reverse('api:host_list')
|
||||
filters = [
|
||||
"created_by__password__icontains=pas3w3rd",
|
||||
"search=foo or created_by__password__icontains=pas3w3rd",
|
||||
"created_by__password__icontains=passw3rd or search=foo",
|
||||
]
|
||||
for f in filters:
|
||||
params = "?host_filter=%s" % urllib.parse.quote(f, safe='')
|
||||
response = get(url + params, admin_user)
|
||||
assert response.status_code == 400, f"Expected 400 for filter: {f}"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_unicode_host_filter(get, admin_user):
|
||||
org = Organization.objects.create(name="unicode-org")
|
||||
inv = Inventory.objects.create(name="unicode-inv", organization=org)
|
||||
host = Host.objects.create(name="ホスト", inventory=inv)
|
||||
group = Group.objects.create(name="グループ", inventory=inv)
|
||||
group.hosts.add(host)
|
||||
|
||||
response = host_filter_get(get, admin_user, "name=ホスト")
|
||||
assert response.status_code == 200
|
||||
assert len(response.data['results']) == 1
|
||||
assert response.data['results'][0]['id'] == host.id
|
||||
|
||||
response = host_filter_get(get, admin_user, "groups__name=グループ")
|
||||
assert response.status_code == 200
|
||||
assert len(response.data['results']) == 1
|
||||
assert response.data['results'][0]['id'] == host.id
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_filter",
|
||||
["string_without_equals", "1", "1.0", "true"],
|
||||
ids=["bare_string", "integer", "float", "bool"],
|
||||
)
|
||||
def test_invalid_host_filter(get, admin_user, invalid_filter):
|
||||
url = reverse('api:host_list')
|
||||
params = "?host_filter=%s" % urllib.parse.quote(invalid_filter, safe='')
|
||||
response = get(url + params, admin_user)
|
||||
assert response.status_code == 400
|
||||
|
||||
@@ -7,7 +7,7 @@ from django.core.exceptions import ValidationError
|
||||
|
||||
from awx.api.versioning import reverse
|
||||
|
||||
from awx.main.models import InventorySource, Inventory, ActivityStream
|
||||
from awx.main.models import InventorySource, Inventory, ActivityStream, Organization
|
||||
from awx.main.utils.inventory_vars import update_group_variables
|
||||
|
||||
|
||||
@@ -963,3 +963,45 @@ class TestInventoryAllVariables:
|
||||
# Test step 6: Value of var x from source A reappears, because the
|
||||
# latest update from source B did not contain var x.
|
||||
self.update_and_verify(inv_src_c, {}, expect={"x": 1}, teststep=6)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_inventory_names_unique_per_organization(post, admin_user):
|
||||
"""Validate that two inventories can have the same name if they belong to different organizations."""
|
||||
org1 = Organization.objects.create(name='org-inv-1')
|
||||
org2 = Organization.objects.create(name='org-inv-2')
|
||||
inv_name = 'SharedInventoryName'
|
||||
|
||||
# Create inventory with same name in org1
|
||||
resp1 = post(
|
||||
reverse('api:inventory_list'),
|
||||
{'name': inv_name, 'organization': org1.id},
|
||||
admin_user,
|
||||
expect=201,
|
||||
)
|
||||
inv1_id = resp1.data['id']
|
||||
|
||||
# Create inventory with same name in org2 - should succeed
|
||||
resp2 = post(
|
||||
reverse('api:inventory_list'),
|
||||
{'name': inv_name, 'organization': org2.id},
|
||||
admin_user,
|
||||
expect=201,
|
||||
)
|
||||
inv2_id = resp2.data['id']
|
||||
|
||||
assert inv1_id != inv2_id
|
||||
inv1 = Inventory.objects.get(id=inv1_id)
|
||||
inv2 = Inventory.objects.get(id=inv2_id)
|
||||
assert inv1.name == inv2.name == inv_name
|
||||
assert inv1.organization.id == org1.id
|
||||
assert inv2.organization.id == org2.id
|
||||
|
||||
# Attempt to create another inventory with same name in org1 - should fail
|
||||
resp3 = post(
|
||||
reverse('api:inventory_list'),
|
||||
{'name': inv_name, 'organization': org1.id},
|
||||
admin_user,
|
||||
expect=400,
|
||||
)
|
||||
assert 'Inventory with this Name and Organization already exists' in json.dumps(resp3.data)
|
||||
|
||||
92
awx/main/tests/functional/api/test_notification_templates.py
Normal file
92
awx/main/tests/functional/api/test_notification_templates.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import pytest
|
||||
|
||||
from awx.api.versioning import reverse
|
||||
from awx.main.models import NotificationTemplate, Organization
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_notification_template_names_unique_per_organization(post, admin_user):
|
||||
"""
|
||||
Validate that notification templates must have unique names within an organization,
|
||||
but can have the same name across different organizations.
|
||||
"""
|
||||
org1 = Organization.objects.create(name='org-notif-1')
|
||||
org2 = Organization.objects.create(name='org-notif-2')
|
||||
template_name = 'SharedNotificationName'
|
||||
|
||||
# Create notification template in org1
|
||||
resp1 = post(
|
||||
reverse('api:notification_template_list'),
|
||||
{
|
||||
'name': template_name,
|
||||
'organization': org1.id,
|
||||
'notification_type': 'email',
|
||||
'notification_configuration': {
|
||||
'username': 'user@example.com',
|
||||
'password': 'pass',
|
||||
'sender': 'sender@example.com',
|
||||
'recipients': ['recipient@example.com'],
|
||||
'host': 'smtp.example.com',
|
||||
'port': 25,
|
||||
'use_tls': False,
|
||||
'use_ssl': False,
|
||||
},
|
||||
},
|
||||
admin_user,
|
||||
expect=201,
|
||||
)
|
||||
template1_id = resp1.data['id']
|
||||
|
||||
# Create notification template with same name in org2 - should succeed
|
||||
resp2 = post(
|
||||
reverse('api:notification_template_list'),
|
||||
{
|
||||
'name': template_name,
|
||||
'organization': org2.id,
|
||||
'notification_type': 'email',
|
||||
'notification_configuration': {
|
||||
'username': 'user@example.com',
|
||||
'password': 'pass',
|
||||
'sender': 'sender@example.com',
|
||||
'recipients': ['recipient@example.com'],
|
||||
'host': 'smtp.example.com',
|
||||
'port': 25,
|
||||
'use_tls': False,
|
||||
'use_ssl': False,
|
||||
},
|
||||
},
|
||||
admin_user,
|
||||
expect=201,
|
||||
)
|
||||
template2_id = resp2.data['id']
|
||||
|
||||
assert template1_id != template2_id
|
||||
template1 = NotificationTemplate.objects.get(id=template1_id)
|
||||
template2 = NotificationTemplate.objects.get(id=template2_id)
|
||||
assert template1.name == template2.name == template_name
|
||||
assert template1.organization.id == org1.id
|
||||
assert template2.organization.id == org2.id
|
||||
|
||||
# Attempt to create another notification template with same name in org1 - should fail
|
||||
resp3 = post(
|
||||
reverse('api:notification_template_list'),
|
||||
{
|
||||
'name': template_name,
|
||||
'organization': org1.id,
|
||||
'notification_type': 'email',
|
||||
'notification_configuration': {
|
||||
'username': 'user@example.com',
|
||||
'password': 'pass',
|
||||
'sender': 'sender@example.com',
|
||||
'recipients': ['recipient@example.com'],
|
||||
'host': 'smtp.example.com',
|
||||
'port': 25,
|
||||
'use_tls': False,
|
||||
'use_ssl': False,
|
||||
},
|
||||
},
|
||||
admin_user,
|
||||
expect=400,
|
||||
)
|
||||
assert 'Notification template with this Organization and Name already exists' in str(resp3.data)
|
||||
311
awx/main/tests/functional/api/test_oidc_credential_test.py
Normal file
311
awx/main/tests/functional/api/test_oidc_credential_test.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
Tests for OIDC workload identity credential test endpoints.
|
||||
|
||||
Tests the /api/v2/credentials/<id>/test/ and /api/v2/credential_types/<id>/test/
|
||||
endpoints when used with OIDC-enabled credential types.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
from django.test import override_settings
|
||||
|
||||
from awx.main.models import Credential, CredentialType, JobTemplate
|
||||
from awx.api.versioning import reverse
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def job_template(organization, project):
|
||||
"""Job template with organization and project for OIDC JWT generation."""
|
||||
return JobTemplate.objects.create(name='test-jt', organization=organization, project=project, playbook='helloworld.yml')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oidc_credentialtype():
|
||||
"""Create a credential type with workload_identity_token internal field."""
|
||||
oidc_type_inputs = {
|
||||
'fields': [
|
||||
{'id': 'url', 'label': 'Vault URL', 'type': 'string', 'help_text': 'The Vault server URL.'},
|
||||
{'id': 'auth_path', 'label': 'Auth Path', 'type': 'string', 'help_text': 'JWT auth mount path.'},
|
||||
{'id': 'role_id', 'label': 'Role ID', 'type': 'string', 'help_text': 'Vault role.'},
|
||||
{'id': 'workload_identity_token', 'label': 'Workload Identity Token', 'type': 'string', 'secret': True, 'internal': True},
|
||||
],
|
||||
'metadata': [
|
||||
{'id': 'secret_path', 'label': 'Secret Path', 'type': 'string'},
|
||||
{'id': 'job_template_id', 'label': 'Job Template ID', 'type': 'string'},
|
||||
],
|
||||
'required': ['url', 'auth_path', 'role_id'],
|
||||
}
|
||||
|
||||
class MockPlugin(object):
|
||||
def backend(self, **kwargs):
|
||||
# Simulate successful backend call
|
||||
return 'secret'
|
||||
|
||||
with mock.patch('awx.main.models.credential.CredentialType.plugin', new_callable=mock.PropertyMock) as mock_plugin:
|
||||
mock_plugin.return_value = MockPlugin()
|
||||
oidc_type = CredentialType(kind='external', managed=True, namespace='hashivault-kv-oidc', name='HashiCorp Vault KV (OIDC)', inputs=oidc_type_inputs)
|
||||
oidc_type.save()
|
||||
yield oidc_type
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oidc_credential(oidc_credentialtype):
|
||||
"""Create a credential using the OIDC credential type."""
|
||||
return Credential.objects.create(
|
||||
credential_type=oidc_credentialtype,
|
||||
name='oidc-vault-cred',
|
||||
inputs={'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role'},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oidc_backend():
|
||||
"""Fixture that mocks OIDC JWT generation and credential backend."""
|
||||
with mock.patch('awx.api.views.retrieve_workload_identity_jwt_with_claims') as mock_jwt, mock.patch('awx.api.views._jwt_decode') as mock_decode, mock.patch(
|
||||
'awx.main.models.credential.CredentialType.plugin', new_callable=mock.PropertyMock
|
||||
) as mock_plugin:
|
||||
|
||||
# Set default return values
|
||||
mock_jwt.return_value = 'fake.jwt.token'
|
||||
mock_decode.return_value = {'iss': 'http://gateway/o', 'aud': 'vault'}
|
||||
|
||||
# Create mock backend
|
||||
mock_backend = mock.MagicMock()
|
||||
mock_backend.backend.return_value = 'secret'
|
||||
mock_plugin.return_value = mock_backend
|
||||
|
||||
# Yield all mocks for test customization
|
||||
yield {
|
||||
'jwt': mock_jwt,
|
||||
'decode': mock_decode,
|
||||
'plugin': mock_plugin,
|
||||
'backend': mock_backend,
|
||||
}
|
||||
|
||||
|
||||
# --- Tests for CredentialExternalTest endpoint ---
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=False)
|
||||
def test_credential_test_without_oidc_feature_flag(post, admin, oidc_credential):
|
||||
"""Test that credential test works without OIDC feature flag enabled."""
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
|
||||
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': '1'}}
|
||||
|
||||
with mock.patch('awx.main.models.credential.CredentialType.plugin', new_callable=mock.PropertyMock) as mock_plugin:
|
||||
mock_backend = mock.MagicMock()
|
||||
mock_backend.backend.return_value = 'secret'
|
||||
mock_plugin.return_value = mock_backend
|
||||
|
||||
response = post(url, data, admin)
|
||||
assert response.status_code == 202
|
||||
# Should not contain JWT payload when feature flag is disabled
|
||||
assert 'details' not in response.data or 'sent_jwt_payload' not in response.data.get('details', {})
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
@pytest.mark.parametrize(
|
||||
'job_template_id, expected_error',
|
||||
[
|
||||
(None, 'Job template ID is required'),
|
||||
('not-an-integer', 'must be an integer'),
|
||||
('99999', 'does not exist'),
|
||||
],
|
||||
ids=['missing_job_template_id', 'invalid_job_template_id_type', 'nonexistent_job_template_id'],
|
||||
)
|
||||
def test_credential_test_job_template_validation(mock_flag, post, admin, oidc_credential, job_template_id, expected_error):
|
||||
"""Test that invalid job_template_id values return 400 with appropriate error messages."""
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
|
||||
data = {'metadata': {'secret_path': 'test/secret'}}
|
||||
if job_template_id is not None:
|
||||
data['metadata']['job_template_id'] = job_template_id
|
||||
|
||||
response = post(url, data, admin)
|
||||
assert response.status_code == 400
|
||||
assert 'details' in response.data
|
||||
assert 'error_message' in response.data['details']
|
||||
assert expected_error in response.data['details']['error_message']
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_test_no_access_to_job_template(mock_flag, post, alice, oidc_credential, job_template):
|
||||
"""Test that user without access to job template gets 403."""
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
|
||||
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
|
||||
|
||||
# Give alice use permission on credential but not on job template
|
||||
oidc_credential.use_role.members.add(alice)
|
||||
|
||||
response = post(url, data, alice)
|
||||
assert response.status_code == 403
|
||||
assert 'You do not have access to job template' in str(response.data)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_test_success_returns_jwt_payload(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend):
|
||||
"""Test that successful test returns JWT payload in response."""
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
|
||||
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
|
||||
|
||||
# Customize mock for this test
|
||||
mock_oidc_backend['decode'].return_value = {
|
||||
'iss': 'http://gateway/o',
|
||||
'sub': 'system:serviceaccount:default:awx-operator',
|
||||
'aud': 'vault',
|
||||
'job_template_id': job_template.id,
|
||||
}
|
||||
|
||||
response = post(url, data, admin)
|
||||
assert response.status_code == 202
|
||||
assert 'details' in response.data
|
||||
assert 'sent_jwt_payload' in response.data['details']
|
||||
assert response.data['details']['sent_jwt_payload']['job_template_id'] == job_template.id
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_test_response_does_not_contain_secret_value(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend):
|
||||
"""
|
||||
the OIDC credential test endpoint must not echo the resolved Vault secret back to the caller.
|
||||
"""
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
|
||||
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
|
||||
|
||||
credential_secret_value = 'CREDENTIAL_SECRET'
|
||||
mock_oidc_backend['backend'].backend.return_value = credential_secret_value
|
||||
|
||||
response = post(url, data, admin)
|
||||
|
||||
assert response.status_code == 202
|
||||
assert 'details' in response.data
|
||||
assert 'sent_jwt_payload' in response.data['details']
|
||||
assert 'secret_value' not in response.data['details']
|
||||
assert credential_secret_value not in str(response.data)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_test_backend_failure_returns_jwt_and_error(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend):
|
||||
"""Test that backend failure still returns JWT payload along with error message."""
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
|
||||
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
|
||||
|
||||
# Make backend fail
|
||||
mock_oidc_backend['backend'].backend.side_effect = RuntimeError('Connection failed')
|
||||
|
||||
response = post(url, data, admin)
|
||||
assert response.status_code == 400
|
||||
assert 'details' in response.data
|
||||
# Both JWT payload and error message should be present
|
||||
assert 'sent_jwt_payload' in response.data['details']
|
||||
assert 'error_message' in response.data['details']
|
||||
assert 'Connection failed' in response.data['details']['error_message']
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_test_jwt_generation_failure(mock_flag, post, admin, oidc_credential, job_template):
|
||||
"""Test that JWT generation failure returns error without JWT payload."""
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
|
||||
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
|
||||
|
||||
with mock.patch('awx.api.views.OIDCCredentialTestMixin._get_workload_identity_token') as mock_jwt:
|
||||
mock_jwt.side_effect = RuntimeError('Failed to generate JWT')
|
||||
|
||||
response = post(url, data, admin)
|
||||
assert response.status_code == 400
|
||||
assert 'details' in response.data
|
||||
assert 'error_message' in response.data['details']
|
||||
assert 'Failed to generate JWT' in response.data['details']['error_message']
|
||||
# No JWT payload when generation fails
|
||||
assert 'sent_jwt_payload' not in response.data['details']
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_test_job_template_id_not_passed_to_backend(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend):
|
||||
"""Test that job_template_id is removed from backend_kwargs."""
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk})
|
||||
data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}}
|
||||
|
||||
response = post(url, data, admin)
|
||||
assert response.status_code == 202
|
||||
|
||||
# Check that backend was called without job_template_id but with url and workload_identity_token
|
||||
call_kwargs = mock_oidc_backend['backend'].backend.call_args[1]
|
||||
assert 'job_template_id' not in call_kwargs
|
||||
assert 'url' in call_kwargs
|
||||
assert 'workload_identity_token' in call_kwargs
|
||||
|
||||
|
||||
# --- Tests for CredentialTypeExternalTest endpoint ---
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_type_test_response_does_not_contain_secret_value(mock_flag, post, admin, oidc_credentialtype, job_template, mock_oidc_backend):
|
||||
"""
|
||||
the credential-type variant of the test endpoint should not return the secret value
|
||||
"""
|
||||
url = reverse('api:credential_type_external_test', kwargs={'pk': oidc_credentialtype.pk})
|
||||
data = {
|
||||
'inputs': {'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role', 'jwt_aud': 'vault'},
|
||||
'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)},
|
||||
}
|
||||
|
||||
credential_type_seret_value = 'CREDENTIAL_TYPE_SECRET'
|
||||
mock_oidc_backend['backend'].backend.return_value = credential_type_seret_value
|
||||
response = post(url, data, admin)
|
||||
|
||||
assert response.status_code == 202
|
||||
assert 'details' in response.data
|
||||
assert 'sent_jwt_payload' in response.data['details']
|
||||
assert 'secret_value' not in response.data['details']
|
||||
assert credential_type_seret_value not in str(response.data)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_type_test_missing_job_template_id(mock_flag, post, admin, oidc_credentialtype):
|
||||
"""Test that missing job_template_id returns 400 for credential type test endpoint."""
|
||||
url = reverse('api:credential_type_external_test', kwargs={'pk': oidc_credentialtype.pk})
|
||||
data = {
|
||||
'inputs': {'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role'},
|
||||
'metadata': {'secret_path': 'test/secret'},
|
||||
}
|
||||
|
||||
response = post(url, data, admin)
|
||||
assert response.status_code == 400
|
||||
assert 'details' in response.data
|
||||
assert 'error_message' in response.data['details']
|
||||
assert 'Job template ID is required' in response.data['details']['error_message']
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch('awx.api.views.flag_enabled', return_value=True)
|
||||
def test_credential_type_test_success_returns_jwt_payload(mock_flag, post, admin, oidc_credentialtype, job_template, mock_oidc_backend):
|
||||
"""Test that successful credential type test returns JWT payload."""
|
||||
url = reverse('api:credential_type_external_test', kwargs={'pk': oidc_credentialtype.pk})
|
||||
data = {
|
||||
'inputs': {'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role'},
|
||||
'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)},
|
||||
}
|
||||
|
||||
response = post(url, data, admin)
|
||||
assert response.status_code == 202
|
||||
assert 'details' in response.data
|
||||
assert 'sent_jwt_payload' in response.data['details']
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_credential_external_test_returns_400_for_non_external_credential(post, admin, credential):
|
||||
# credential fixture creates a non-external credential (e.g. SSH/vault kind)
|
||||
url = reverse('api:credential_external_test', kwargs={'pk': credential.pk})
|
||||
response = post(url, {'metadata': {}}, admin)
|
||||
assert response.status_code == 400
|
||||
assert 'not testable' in response.data.get('detail', '').lower()
|
||||
@@ -139,6 +139,7 @@ def test_survey_password_default(post, patch, admin_user, project, inventory, su
|
||||
("DTSTART:20300308T050000Z", "One or more rule required in rrule"),
|
||||
("DTSTART:20300308T050000Z RRULE:FREQ=MONTHLY;INTERVAL=1; EXDATE:20220401", "EXDATE not allowed in rrule"),
|
||||
("DTSTART:20300308T050000Z RRULE:FREQ=MONTHLY;INTERVAL=1; RDATE:20220401", "RDATE not allowed in rrule"),
|
||||
("DTSTART:20300308T050000Z RRULE:FREQ=YEARLY;INTERVAL=0;BYDAY=MO", "INTERVAL must be a positive integer"),
|
||||
("DTSTART:20300308T050000Z RRULE:FREQ=SECONDLY;INTERVAL=5;COUNT=6", "SECONDLY is not supported"),
|
||||
# Individual rule test
|
||||
("DTSTART:20300308T050000Z RRULE:NONSENSE", "INTERVAL required in rrule"),
|
||||
@@ -202,6 +203,7 @@ def test_multiple_invalid_rrules(post, admin_user, project, inventory):
|
||||
"rrule": [
|
||||
"Multiple DTSTART is not supported.",
|
||||
"INTERVAL required in rrule: RULE:FREQ=SECONDLY",
|
||||
"SECONDLY is not supported: RULE:FREQ=SECONDLY",
|
||||
"RRULE may not contain both COUNT and UNTIL: RULE:FREQ=MINUTELY;INTERVAL=10;COUNT=5;UNTIL=20220101",
|
||||
"rrule parsing failed validation: 'NoneType' object has no attribute 'group'",
|
||||
]
|
||||
|
||||
191
awx/main/tests/functional/api/test_smart_inventory.py
Normal file
191
awx/main/tests/functional/api/test_smart_inventory.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from awx.api.versioning import reverse
|
||||
from awx.main.models import Organization, Host, Group, Inventory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def smart_inv_org():
|
||||
return Organization.objects.create(name="smart-org")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def smart_inv_source(smart_inv_org):
|
||||
inv = Inventory.objects.create(name="smart-source-inv", organization=smart_inv_org)
|
||||
Host.objects.create(name="hostA", inventory=inv)
|
||||
Host.objects.create(name="hostB", inventory=inv)
|
||||
Host.objects.create(name="hostDup", inventory=inv)
|
||||
groupA = Group.objects.create(name="groupA", inventory=inv)
|
||||
groupB = Group.objects.create(name="groupB", inventory=inv)
|
||||
groupA.hosts.add(*inv.hosts.filter(name__in=["hostA", "hostDup"]))
|
||||
groupB.hosts.add(*inv.hosts.filter(name__in=["hostB", "hostDup"]))
|
||||
return inv
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_create_smart_inventory(post, admin_user, smart_inv_org):
|
||||
resp = post(
|
||||
reverse('api:inventory_list'),
|
||||
{
|
||||
'name': 'my-smart-inv',
|
||||
'kind': 'smart',
|
||||
'organization': smart_inv_org.pk,
|
||||
'host_filter': 'name=hostA',
|
||||
},
|
||||
admin_user,
|
||||
expect=201,
|
||||
)
|
||||
assert resp.data['kind'] == 'smart'
|
||||
assert resp.data['host_filter'] == 'name=hostA'
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_create_smart_inventory_requires_host_filter(post, admin_user, smart_inv_org):
|
||||
resp = post(
|
||||
reverse('api:inventory_list'),
|
||||
{
|
||||
'name': 'no-filter-smart',
|
||||
'kind': 'smart',
|
||||
'organization': smart_inv_org.pk,
|
||||
},
|
||||
admin_user,
|
||||
expect=400,
|
||||
)
|
||||
assert 'host_filter' in json.dumps(resp.data)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_unable_to_create_host_in_smart_inventory(post, admin_user, smart_inv_org):
|
||||
smart_inv = Inventory.objects.create(
|
||||
name="no-host-create",
|
||||
kind="smart",
|
||||
host_filter="name=hostA",
|
||||
organization=smart_inv_org,
|
||||
)
|
||||
url = reverse('api:inventory_hosts_list', kwargs={'pk': smart_inv.pk})
|
||||
resp = post(url, {'name': 'new-host'}, admin_user, expect=400)
|
||||
assert 'Cannot create' in json.dumps(resp.data)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_unable_to_create_group_in_smart_inventory(post, admin_user, smart_inv_org):
|
||||
smart_inv = Inventory.objects.create(
|
||||
name="no-group-create",
|
||||
kind="smart",
|
||||
host_filter="name=hostA",
|
||||
organization=smart_inv_org,
|
||||
)
|
||||
url = reverse('api:inventory_groups_list', kwargs={'pk': smart_inv.pk})
|
||||
resp = post(url, {'name': 'new-group'}, admin_user, expect=400)
|
||||
assert 'Cannot create' in json.dumps(resp.data)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_unable_to_create_inventory_source_in_smart_inventory(post, admin_user, smart_inv_org):
|
||||
smart_inv = Inventory.objects.create(
|
||||
name="no-src-create",
|
||||
kind="smart",
|
||||
host_filter="name=hostA",
|
||||
organization=smart_inv_org,
|
||||
)
|
||||
url = reverse('api:inventory_inventory_sources_list', kwargs={'pk': smart_inv.pk})
|
||||
resp = post(url, {'name': 'new-src', 'source': 'ec2'}, admin_user, expect=400)
|
||||
assert 'Cannot create' in json.dumps(resp.data)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_convert_smart_to_regular_inventory(admin_user, smart_inv_org):
|
||||
smart_inv = Inventory.objects.create(
|
||||
name="convert-to-regular",
|
||||
kind="smart",
|
||||
host_filter="name=anything",
|
||||
organization=smart_inv_org,
|
||||
)
|
||||
assert smart_inv.kind == 'smart'
|
||||
smart_inv.host_filter = ''
|
||||
smart_inv.kind = ''
|
||||
smart_inv.save()
|
||||
smart_inv.refresh_from_db()
|
||||
assert smart_inv.kind == ''
|
||||
assert not smart_inv.host_filter
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_smart_inventory_deletion_does_not_cascade(admin_user, smart_inv_source, smart_inv_org):
|
||||
host = smart_inv_source.hosts.first()
|
||||
smart_inv = Inventory.objects.create(
|
||||
name="delete-no-cascade",
|
||||
kind="smart",
|
||||
host_filter="name=%s" % host.name,
|
||||
organization=smart_inv_org,
|
||||
)
|
||||
smart_inv.delete()
|
||||
assert Host.objects.filter(pk=host.pk).exists()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_urlencode_host_filter(post, admin_user, smart_inv_org):
|
||||
post(
|
||||
reverse('api:inventory_list'),
|
||||
data={
|
||||
'name': 'url-encoded-smart',
|
||||
'kind': 'smart',
|
||||
'organization': smart_inv_org.pk,
|
||||
'host_filter': 'ansible_facts__ansible_distribution_version=%227.4%22',
|
||||
},
|
||||
user=admin_user,
|
||||
expect=201,
|
||||
)
|
||||
si = Inventory.objects.get(name='url-encoded-smart')
|
||||
assert si.host_filter == 'ansible_facts__ansible_distribution_version="7.4"'
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_host_filter_unicode(post, admin_user, smart_inv_org):
|
||||
post(
|
||||
reverse('api:inventory_list'),
|
||||
data={
|
||||
'name': 'unicode-smart',
|
||||
'kind': 'smart',
|
||||
'organization': smart_inv_org.pk,
|
||||
'host_filter': u'ansible_facts__ansible_distribution=レッドハット',
|
||||
},
|
||||
user=admin_user,
|
||||
expect=201,
|
||||
)
|
||||
si = Inventory.objects.get(name='unicode-smart')
|
||||
assert si.host_filter == u'ansible_facts__ansible_distribution=レッドハット'
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.parametrize("lookup", ['icontains', 'has_keys'])
|
||||
def test_host_filter_invalid_ansible_facts_lookup(post, admin_user, smart_inv_org, lookup):
|
||||
resp = post(
|
||||
reverse('api:inventory_list'),
|
||||
data={
|
||||
'name': 'invalid-lookup-smart',
|
||||
'kind': 'smart',
|
||||
'organization': smart_inv_org.pk,
|
||||
'host_filter': u'ansible_facts__ansible_distribution__{}=cent'.format(lookup),
|
||||
},
|
||||
user=admin_user,
|
||||
expect=400,
|
||||
)
|
||||
assert 'ansible_facts does not support searching with __{}'.format(lookup) in json.dumps(resp.data)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_host_filter_ansible_facts_exact(post, admin_user, smart_inv_org):
|
||||
post(
|
||||
reverse('api:inventory_list'),
|
||||
data={
|
||||
'name': 'exact-smart',
|
||||
'kind': 'smart',
|
||||
'organization': smart_inv_org.pk,
|
||||
'host_filter': 'ansible_facts__ansible_distribution__exact="CentOS"',
|
||||
},
|
||||
user=admin_user,
|
||||
expect=201,
|
||||
)
|
||||
@@ -13,6 +13,7 @@ from awx.main.models.workflow import (
|
||||
WorkflowJobTemplateNode,
|
||||
)
|
||||
from awx.main.models.credential import Credential
|
||||
from awx.main.models.label import Label
|
||||
from awx.main.scheduler import TaskManager, WorkflowManager, DependencyManager
|
||||
|
||||
# Django
|
||||
@@ -51,6 +52,31 @@ def test_node_accepts_prompted_fields(inventory, project, workflow_job_template,
|
||||
post(url, {'unified_job_template': job_template.pk, 'limit': 'webservers'}, user=admin_user, expect=201)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_node_extra_data_patch_with_unprompted_labels(inventory, project, organization, workflow_job_template, patch, admin_user):
|
||||
"""AAP-41742: PATCH extra_data on a workflow node should succeed even when
|
||||
the node has labels associated but the JT has ask_labels_on_launch=False."""
|
||||
jt = JobTemplate.objects.create(
|
||||
inventory=inventory,
|
||||
project=project,
|
||||
playbook='helloworld.yml',
|
||||
ask_variables_on_launch=True,
|
||||
ask_labels_on_launch=False,
|
||||
)
|
||||
label = Label.objects.create(name='repro-label', organization=organization)
|
||||
|
||||
node = WorkflowJobTemplateNode.objects.create(
|
||||
workflow_job_template=workflow_job_template,
|
||||
unified_job_template=jt,
|
||||
extra_data={'foo': 'bar'},
|
||||
)
|
||||
node.labels.add(label)
|
||||
|
||||
url = reverse('api:workflow_job_template_node_detail', kwargs={'pk': node.pk})
|
||||
r = patch(url, {'extra_data': {'foo': 'edited'}}, user=admin_user, expect=200)
|
||||
assert r.data['extra_data'] == {'foo': 'edited'}
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.parametrize(
|
||||
"field_name, field_value",
|
||||
|
||||
@@ -131,14 +131,18 @@ def test_workflow_creation_permissions(setup_managed_roles, organization, workfl
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_assign_credential_to_user_of_another_org(setup_managed_roles, credential, admin_user, rando, org_admin, organization, post):
|
||||
'''Test that a credential can only be assigned to a user in the same organization'''
|
||||
# cannot assign credential to rando, as rando is not in the same org as the credential
|
||||
'''Test that a credential can only be assigned to a user in the same organization by non-superusers'''
|
||||
rd = RoleDefinition.objects.get(name="Credential Admin")
|
||||
credential.organization = organization
|
||||
credential.save(update_fields=['organization'])
|
||||
assert credential.organization not in Organization.access_qs(rando, 'member')
|
||||
url = django_reverse('roleuserassignment-list')
|
||||
resp = post(url=url, data={"user": rando.id, "role_definition": rd.id, "object_id": credential.id}, user=admin_user, expect=400)
|
||||
|
||||
# superuser can assign cross-org
|
||||
post(url=url, data={"user": rando.id, "role_definition": rd.id, "object_id": credential.id}, user=admin_user, expect=201)
|
||||
|
||||
# non-superuser (org_admin) cannot assign cross-org
|
||||
resp = post(url=url, data={"user": rando.id, "role_definition": rd.id, "object_id": credential.id}, user=org_admin, expect=400)
|
||||
assert "You cannot grant credential access to a User not in the credentials' organization" in str(resp.data)
|
||||
|
||||
# can assign credential to superuser
|
||||
@@ -146,7 +150,7 @@ def test_assign_credential_to_user_of_another_org(setup_managed_roles, credentia
|
||||
rando.save()
|
||||
post(url=url, data={"user": rando.id, "role_definition": rd.id, "object_id": credential.id}, user=admin_user, expect=201)
|
||||
|
||||
# can assign credential to org_admin
|
||||
# can assign credential to org_admin (same org)
|
||||
assert credential.organization in Organization.access_qs(org_admin, 'member')
|
||||
post(url=url, data={"user": org_admin.id, "role_definition": rd.id, "object_id": credential.id}, user=admin_user, expect=201)
|
||||
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Tests for AAP-68023: host_list_rbac performance optimization.
|
||||
|
||||
The host list endpoint fetches the large ansible_facts JSON column
|
||||
unnecessarily. The HostManager now defers it by default so that
|
||||
list queries avoid transferring this data from PostgreSQL.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from awx.main.models import Host
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AAP-68023: Verify ansible_facts column is deferred by HostManager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestHostManagerDeferral:
|
||||
"""AAP-68023: The host list fetches 200+ columns unnecessarily.
|
||||
|
||||
The ansible_facts JSON column is large and not used by the list
|
||||
serializer. HostManager.get_queryset() must defer it so that
|
||||
every query through Host.objects avoids fetching it by default.
|
||||
"""
|
||||
|
||||
def test_ansible_facts_deferred_by_default(self):
|
||||
"""ansible_facts should be in the deferred set for default Host queries."""
|
||||
qs = Host.objects.all()
|
||||
deferred = qs.query.deferred_loading[0]
|
||||
assert 'ansible_facts' in deferred, f'ansible_facts should be deferred by the HostManager. ' f'Deferred fields: {deferred}'
|
||||
|
||||
def test_ansible_facts_accessible_when_needed(self, inventory):
|
||||
"""Deferred fields are still accessible — Django fetches on access."""
|
||||
host = Host.objects.create(
|
||||
name='facts-host',
|
||||
inventory=inventory,
|
||||
ansible_facts={'os': 'linux'},
|
||||
)
|
||||
loaded = Host.objects.get(pk=host.pk)
|
||||
assert loaded.ansible_facts == {'os': 'linux'}
|
||||
240
awx/main/tests/functional/dab_rbac/test_notification_rbac.py
Normal file
240
awx/main/tests/functional/dab_rbac/test_notification_rbac.py
Normal file
@@ -0,0 +1,240 @@
|
||||
import pytest
|
||||
|
||||
from awx.api.versioning import reverse
|
||||
from awx.main.models import NotificationTemplate, Organization
|
||||
|
||||
from ansible_base.rbac.models import RoleDefinition
|
||||
from ansible_base.rbac import permission_registry
|
||||
|
||||
NT_DATA = {
|
||||
'notification_type': 'webhook',
|
||||
'notification_configuration': {
|
||||
'url': 'http://localhost',
|
||||
'username': '',
|
||||
'password': '',
|
||||
'headers': {},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def nt_url(pk):
|
||||
return reverse('api:notification_template_detail', kwargs={'pk': pk})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nt_add_role(setup_managed_roles):
|
||||
"""A custom role with only add_notificationtemplate and view_organization.
|
||||
This is intentionally narrower than Organization NotificationTemplate Admin
|
||||
so that give_creator_permissions actually creates creator permissions."""
|
||||
rd, _ = RoleDefinition.objects.get_or_create(
|
||||
name='nt-add-only',
|
||||
permissions=['add_notificationtemplate', 'view_organization'],
|
||||
content_type=permission_registry.content_type_model.objects.get_for_model(Organization),
|
||||
)
|
||||
return rd
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_create_with_add_only_role_gets_creator_permissions(rando, organization, post, get, patch, nt_add_role):
|
||||
"""User with only add permission creates a notification template and gets
|
||||
creator permissions (change, delete, view) via give_creator_permissions.
|
||||
This exercises the fix for models without old-style roles (AAP-57274)."""
|
||||
nt_add_role.give_permission(rando, organization)
|
||||
|
||||
r = post(
|
||||
reverse('api:notification_template_list'),
|
||||
dict(name='rando-nt', organization=organization.id, **NT_DATA),
|
||||
user=rando,
|
||||
expect=201,
|
||||
)
|
||||
nt = NotificationTemplate.objects.get(pk=r.data['id'])
|
||||
assert rando.has_obj_perm(nt, 'change')
|
||||
assert rando.has_obj_perm(nt, 'view')
|
||||
|
||||
# Creator permissions survive revocation of the org-level add role
|
||||
nt_add_role.remove_permission(rando, organization)
|
||||
get(nt_url(nt.pk), user=rando, expect=200)
|
||||
patch(nt_url(nt.pk), data={'description': 'updated'}, user=rando, expect=200)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_org_admin_can_crud(rando, organization, post, get, patch, delete, setup_managed_roles):
|
||||
"""User with org-level notification admin can create, view, edit, and delete"""
|
||||
rd = RoleDefinition.objects.get(name='Organization NotificationTemplate Admin')
|
||||
rd.give_permission(rando, organization)
|
||||
|
||||
r = post(
|
||||
reverse('api:notification_template_list'),
|
||||
dict(name='rando-nt', organization=organization.id, **NT_DATA),
|
||||
user=rando,
|
||||
expect=201,
|
||||
)
|
||||
pk = r.data['id']
|
||||
url = nt_url(pk)
|
||||
|
||||
get(url, user=rando, expect=200)
|
||||
patch(url, data={'description': 'updated'}, user=rando, expect=200)
|
||||
delete(url, user=rando, expect=204)
|
||||
assert not NotificationTemplate.objects.filter(pk=pk).exists()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_unpermissioned_user_cannot_access(rando, notification_template, get, patch, delete, setup_managed_roles):
|
||||
"""User without any permissions cannot view, edit, or delete a notification template"""
|
||||
url = nt_url(notification_template.pk)
|
||||
|
||||
get(url, user=rando, expect=403)
|
||||
patch(url, data={'description': 'nope'}, user=rando, expect=403)
|
||||
delete(url, user=rando, expect=403)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_grant_and_revoke_object_role(rando, notification_template, get, patch, setup_managed_roles):
|
||||
"""Granting and revoking NotificationTemplate Admin role controls access"""
|
||||
rd = RoleDefinition.objects.get(name='NotificationTemplate Admin')
|
||||
url = nt_url(notification_template.pk)
|
||||
|
||||
get(url, user=rando, expect=403)
|
||||
|
||||
rd.give_permission(rando, notification_template)
|
||||
get(url, user=rando, expect=200)
|
||||
patch(url, data={'description': 'changed'}, user=rando, expect=200)
|
||||
|
||||
rd.remove_permission(rando, notification_template)
|
||||
get(url, user=rando, expect=403)
|
||||
patch(url, data={'description': 'nope'}, user=rando, expect=403)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_creator_can_access_sub_endpoints(rando, organization, post, get, nt_add_role):
|
||||
"""Creator can access notification list sub-endpoint"""
|
||||
nt_add_role.give_permission(rando, organization)
|
||||
|
||||
r = post(
|
||||
reverse('api:notification_template_list'),
|
||||
dict(name='rando-nt', organization=organization.id, **NT_DATA),
|
||||
user=rando,
|
||||
expect=201,
|
||||
)
|
||||
pk = r.data['id']
|
||||
|
||||
# Revoke org-level role so only creator permissions remain
|
||||
nt_add_role.remove_permission(rando, organization)
|
||||
|
||||
get(
|
||||
reverse('api:notification_template_notification_list', kwargs={'pk': pk}),
|
||||
user=rando,
|
||||
expect=200,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_list_filtered_by_permissions(rando, admin_user, organization, post, get, nt_add_role):
|
||||
"""Notification template list only shows templates the user has access to"""
|
||||
nt_add_role.give_permission(rando, organization)
|
||||
|
||||
post(
|
||||
reverse('api:notification_template_list'),
|
||||
dict(name='admin-nt', organization=organization.id, **NT_DATA),
|
||||
user=admin_user,
|
||||
expect=201,
|
||||
)
|
||||
post(
|
||||
reverse('api:notification_template_list'),
|
||||
dict(name='rando-nt', organization=organization.id, **NT_DATA),
|
||||
user=rando,
|
||||
expect=201,
|
||||
)
|
||||
|
||||
# rando has org-level add, but admin-nt was created by admin → rando shouldn't see it
|
||||
# unless org admin role also gives view. With add-only role, rando has view_organization
|
||||
# but not view_notificationtemplate at the org level, so they only see their own (via creator perms)
|
||||
nt_add_role.remove_permission(rando, organization)
|
||||
r = get(reverse('api:notification_template_list'), user=rando, expect=200)
|
||||
visible_names = {item['name'] for item in r.data['results']}
|
||||
assert 'rando-nt' in visible_names
|
||||
assert 'admin-nt' not in visible_names
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_creator_access_list_with_add_only_role(rando, organization, post, get, nt_add_role):
|
||||
"""User with add_only role creates a notification template and can access its access_list endpoint"""
|
||||
from ansible_base.rbac.models import DABContentType
|
||||
|
||||
nt_add_role.give_permission(rando, organization)
|
||||
|
||||
r = post(
|
||||
reverse('api:notification_template_list'),
|
||||
dict(name='rando-nt', organization=organization.id, **NT_DATA),
|
||||
user=rando,
|
||||
expect=201,
|
||||
)
|
||||
nt = NotificationTemplate.objects.get(pk=r.data['id'])
|
||||
|
||||
# Revoke org-level role so only creator permissions remain
|
||||
nt_add_role.remove_permission(rando, organization)
|
||||
|
||||
# Creator should be able to access the access_list endpoint for their own notification template
|
||||
# Use the DAB access_list endpoint pattern: /api/v2/role_user_access/{model_name}/{pk}/
|
||||
ct = DABContentType.objects.get_for_model(NotificationTemplate)
|
||||
access_list_url = f'/api/v2/role_user_access/{ct.api_slug}/{nt.pk}/?order_by=id'
|
||||
r = get(access_list_url, user=rando, expect=200)
|
||||
|
||||
# The creator should be listed in the access list
|
||||
usernames = {user['username'] for user in r.data['results']}
|
||||
assert rando.username in usernames
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_unpermissioned_user_cannot_access_access_list(rando, organization, post, admin_user, get, setup_managed_roles):
|
||||
"""User without view permission cannot access the access_list endpoint"""
|
||||
from ansible_base.rbac.models import DABContentType
|
||||
|
||||
# Create a notification template as admin
|
||||
r = post(
|
||||
reverse('api:notification_template_list'),
|
||||
dict(name='admin-nt', organization=organization.id, **NT_DATA),
|
||||
user=admin_user,
|
||||
expect=201,
|
||||
)
|
||||
nt = NotificationTemplate.objects.get(pk=r.data['id'])
|
||||
|
||||
ct = DABContentType.objects.get_for_model(NotificationTemplate)
|
||||
access_list_url = f'/api/v2/role_user_access/{ct.api_slug}/{nt.pk}/?order_by=id'
|
||||
# rando has no permissions on this notification template, so they can't see it or its access list
|
||||
# The endpoint returns 404 (not found) instead of 403 when user can't view the resource
|
||||
get(access_list_url, user=rando, expect=404)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_access_list_shows_creator(rando, organization, post, get, nt_add_role, setup_managed_roles):
|
||||
"""Access list shows the creator with direct permissions"""
|
||||
from ansible_base.rbac.models import DABContentType
|
||||
from ansible_base.rbac.models import RoleDefinition
|
||||
|
||||
nt_add_role.give_permission(rando, organization)
|
||||
|
||||
# rando creates a notification template
|
||||
r = post(
|
||||
reverse('api:notification_template_list'),
|
||||
dict(name='rando-nt', organization=organization.id, **NT_DATA),
|
||||
user=rando,
|
||||
expect=201,
|
||||
)
|
||||
nt = NotificationTemplate.objects.get(pk=r.data['id'])
|
||||
|
||||
# Now assign them the object admin role directly too
|
||||
rd = RoleDefinition.objects.get(name='NotificationTemplate Admin')
|
||||
rd.give_permission(rando, nt)
|
||||
|
||||
ct = DABContentType.objects.get_for_model(NotificationTemplate)
|
||||
access_list_url = f'/api/v2/role_user_access/{ct.api_slug}/{nt.pk}/?order_by=id'
|
||||
r = get(access_list_url, user=rando, expect=200)
|
||||
|
||||
# rando should be listed with direct permissions from both creator and object role assignment
|
||||
user_data = {item['username']: item for item in r.data['results']}
|
||||
assert rando.username in user_data
|
||||
|
||||
# Verify they have direct role assignments
|
||||
assert len(user_data[rando.username]['object_role_assignments']) > 0
|
||||
assert any(assign.get('type') == 'direct' for assign in user_data[rando.username]['object_role_assignments'])
|
||||
@@ -173,6 +173,22 @@ def test_creator_permission(rando, admin_user, inventory, setup_managed_roles):
|
||||
assert rando in inventory.admin_role.members.all()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_creator_permission_notification_template(rando, organization, setup_managed_roles):
|
||||
"""NotificationTemplate has no old-style roles, give_creator_permissions should not error"""
|
||||
from awx.main.models import NotificationTemplate
|
||||
|
||||
nt = NotificationTemplate.objects.create(
|
||||
name='test-nt',
|
||||
organization=organization,
|
||||
notification_type='slack',
|
||||
notification_configuration={'token': 'x', 'channels': ['#test']},
|
||||
)
|
||||
give_creator_permissions(rando, nt)
|
||||
assignment = RoleUserAssignment.objects.filter(user=rando, object_id=nt.pk).first()
|
||||
assert assignment is not None
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_implicit_parents_no_assignments(organization):
|
||||
"""Through the normal course of creating models, we should not be changing DAB RBAC permissions"""
|
||||
|
||||
@@ -8,7 +8,7 @@ from awx.main.management.commands.dispatcherd import _hash_config
|
||||
def test_dispatcherd_config_hash_is_stable(settings, monkeypatch):
|
||||
monkeypatch.setenv('AWX_COMPONENT', 'dispatcher')
|
||||
settings.CLUSTER_HOST_ID = 'test-node'
|
||||
settings.JOB_EVENT_WORKERS = 1
|
||||
settings.DISPATCHER_MIN_WORKERS = 1
|
||||
settings.DISPATCHER_SCHEDULE = {}
|
||||
|
||||
config_one = get_dispatcherd_config(for_service=True)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import pytest
|
||||
|
||||
# AWX context managers for testing
|
||||
from awx.main.signals import disable_activity_stream, disable_computed_fields, update_inventory_computed_fields
|
||||
from awx.main.signals import disable_activity_stream, disable_computed_fields
|
||||
from awx.main.tasks.system import update_inventory_computed_fields
|
||||
|
||||
# AWX models
|
||||
from awx.main.models.organization import Organization
|
||||
|
||||
@@ -71,8 +71,10 @@ class TestEvents:
|
||||
assert s.skipped == 0
|
||||
|
||||
for host in Host.objects.all():
|
||||
assert host.last_job_id == self.job.id
|
||||
assert host.last_job_host_summary.host == host
|
||||
latest_summary = JobHostSummary.latest_for_host(host.id)
|
||||
assert latest_summary is not None
|
||||
assert latest_summary.job_id == self.job.id
|
||||
assert latest_summary.host == host
|
||||
|
||||
def test_host_summary_generation_with_deleted_hosts(self):
|
||||
self._generate_hosts(10)
|
||||
@@ -91,8 +93,7 @@ class TestEvents:
|
||||
def test_host_summary_generation_with_limit(self):
|
||||
# Make an inventory with 10 hosts, run a playbook with a --limit
|
||||
# pointed at *one* host,
|
||||
# Verify that *only* that host has an associated JobHostSummary and that
|
||||
# *only* that host has an updated value for .last_job.
|
||||
# Verify that *only* that host has an associated JobHostSummary.
|
||||
self._generate_hosts(10)
|
||||
|
||||
# by making the playbook_on_stats *only* include Host 1, we're emulating
|
||||
@@ -105,13 +106,14 @@ class TestEvents:
|
||||
# be related to the appropriate Host)
|
||||
assert JobHostSummary.objects.count() == 1
|
||||
for h in Host.objects.all():
|
||||
latest_summary = JobHostSummary.latest_for_host(h.id)
|
||||
if h.name == 'Host 1':
|
||||
assert h.last_job_id == self.job.id
|
||||
assert h.last_job_host_summary_id == JobHostSummary.objects.first().id
|
||||
assert latest_summary is not None
|
||||
assert latest_summary.job_id == self.job.id
|
||||
assert latest_summary.id == JobHostSummary.objects.first().id
|
||||
else:
|
||||
# all other hosts in the inventory should remain untouched
|
||||
assert h.last_job_id is None
|
||||
assert h.last_job_host_summary_id is None
|
||||
# all other hosts in the inventory should have no summary
|
||||
assert latest_summary is None
|
||||
|
||||
def test_host_metrics_insert(self):
|
||||
self._generate_hosts(10)
|
||||
|
||||
213
awx/main/tests/functional/models/test_host_queryset.py
Normal file
213
awx/main/tests/functional/models/test_host_queryset.py
Normal file
@@ -0,0 +1,213 @@
|
||||
import pytest
|
||||
|
||||
from django.test.utils import CaptureQueriesContext
|
||||
from django.db import connection
|
||||
from django.utils.timezone import now
|
||||
|
||||
from awx.main.models import Job, JobEvent, Inventory, Host, JobHostSummary
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestHostLatestSummaryQuerySet:
|
||||
"""Tests for HostLatestSummaryQuerySet and Host.latest_summary property."""
|
||||
|
||||
def _create_inventory_with_hosts(self, count=5):
|
||||
inventory = Inventory()
|
||||
inventory.save()
|
||||
Host.objects.bulk_create([Host(created=now(), modified=now(), name=f'host-{i}', inventory_id=inventory.id) for i in range(count)])
|
||||
return inventory
|
||||
|
||||
def _run_job(self, inventory, host_names=None):
|
||||
"""Run a fake job that creates JobHostSummary records for the given hosts."""
|
||||
if host_names is None:
|
||||
host_names = list(inventory.hosts.values_list('name', flat=True))
|
||||
job = Job(inventory=inventory)
|
||||
job.save()
|
||||
host_map = dict(inventory.hosts.values_list('name', 'id'))
|
||||
JobEvent.create_from_data(
|
||||
job_id=job.pk,
|
||||
parent_uuid='abc123',
|
||||
event='playbook_on_stats',
|
||||
event_data={
|
||||
'ok': {name: 1 for name in host_names},
|
||||
'changed': {},
|
||||
'dark': {},
|
||||
'failures': {},
|
||||
'ignored': {},
|
||||
'processed': {},
|
||||
'rescued': {},
|
||||
'skipped': {},
|
||||
},
|
||||
host_map=host_map,
|
||||
).save()
|
||||
return job
|
||||
|
||||
def test_with_latest_summary_id_annotates_hosts(self):
|
||||
inventory = self._create_inventory_with_hosts(3)
|
||||
job = self._run_job(inventory)
|
||||
|
||||
hosts = Host.objects.filter(inventory=inventory).with_latest_summary_id()
|
||||
for host in hosts:
|
||||
assert hasattr(host, '_latest_summary_id')
|
||||
summary = JobHostSummary.objects.filter(host=host, job=job).first()
|
||||
assert host._latest_summary_id == summary.id
|
||||
|
||||
def test_with_latest_summary_id_returns_most_recent(self):
|
||||
inventory = self._create_inventory_with_hosts(1)
|
||||
self._run_job(inventory)
|
||||
job2 = self._run_job(inventory)
|
||||
|
||||
host = Host.objects.filter(inventory=inventory).with_latest_summary_id().first()
|
||||
latest = JobHostSummary.objects.filter(host_id=host.id).order_by('-id').first()
|
||||
assert latest.job_id == job2.id
|
||||
assert host._latest_summary_id == latest.id
|
||||
|
||||
def test_with_latest_summary_id_none_for_no_summaries(self):
|
||||
inventory = self._create_inventory_with_hosts(1)
|
||||
# No job run — no summaries
|
||||
host = Host.objects.filter(inventory=inventory).with_latest_summary_id().first()
|
||||
assert host._latest_summary_id is None
|
||||
|
||||
def test_fetch_all_bulk_attaches_summaries(self):
|
||||
inventory = self._create_inventory_with_hosts(5)
|
||||
self._run_job(inventory)
|
||||
|
||||
hosts = list(Host.objects.filter(inventory=inventory).with_latest_summary_id())
|
||||
for host in hosts:
|
||||
assert hasattr(host, '_latest_summary_cache')
|
||||
assert host._latest_summary_cache is not None
|
||||
assert isinstance(host._latest_summary_cache, JobHostSummary)
|
||||
|
||||
def test_fetch_all_skips_non_annotated_querysets(self):
|
||||
"""Non-annotated querysets should NOT set _latest_summary_cache,
|
||||
preserving the per-object fallback in Host.latest_summary."""
|
||||
inventory = self._create_inventory_with_hosts(3)
|
||||
self._run_job(inventory)
|
||||
|
||||
hosts = list(Host.objects.filter(inventory=inventory))
|
||||
for host in hosts:
|
||||
assert not hasattr(host, '_latest_summary_cache')
|
||||
|
||||
def test_count_does_not_trigger_fetch_all(self):
|
||||
"""Calling .count() should not trigger _fetch_all or the bulk-attach logic."""
|
||||
inventory = self._create_inventory_with_hosts(5)
|
||||
self._run_job(inventory)
|
||||
|
||||
qs = Host.objects.filter(inventory=inventory).with_latest_summary_id()
|
||||
with CaptureQueriesContext(connection) as ctx:
|
||||
result = qs.count()
|
||||
|
||||
assert result == 5
|
||||
# count() should produce a single COUNT query, not fetch all rows + summaries
|
||||
assert len(ctx.captured_queries) == 1
|
||||
assert 'COUNT' in ctx.captured_queries[0]['sql'].upper()
|
||||
|
||||
def test_exists_does_not_trigger_fetch_all(self):
|
||||
inventory = self._create_inventory_with_hosts(1)
|
||||
self._run_job(inventory)
|
||||
|
||||
qs = Host.objects.filter(inventory=inventory).with_latest_summary_id()
|
||||
with CaptureQueriesContext(connection) as ctx:
|
||||
result = qs.exists()
|
||||
|
||||
assert result is True
|
||||
assert len(ctx.captured_queries) == 1
|
||||
|
||||
def test_latest_summary_property_uses_cache(self):
|
||||
"""When loaded via with_latest_summary_id(), Host.latest_summary
|
||||
should use the bulk-attached cache without extra queries."""
|
||||
inventory = self._create_inventory_with_hosts(3)
|
||||
self._run_job(inventory)
|
||||
|
||||
hosts = list(Host.objects.filter(inventory=inventory).with_latest_summary_id())
|
||||
|
||||
with CaptureQueriesContext(connection) as ctx:
|
||||
for host in hosts:
|
||||
summary = host.latest_summary
|
||||
assert summary is not None
|
||||
|
||||
# No additional queries — all data came from the bulk-attach
|
||||
assert len(ctx.captured_queries) == 0
|
||||
|
||||
def test_latest_summary_property_fallback(self):
|
||||
"""When loaded without annotation, Host.latest_summary should
|
||||
fall back to a per-object query."""
|
||||
inventory = self._create_inventory_with_hosts(1)
|
||||
job = self._run_job(inventory)
|
||||
|
||||
host = Host.objects.filter(inventory=inventory).first()
|
||||
assert not hasattr(host, '_latest_summary_cache')
|
||||
|
||||
summary = host.latest_summary
|
||||
assert summary is not None
|
||||
assert summary.job_id == job.id
|
||||
# After first access, the cache should be populated
|
||||
assert hasattr(host, '_latest_summary_cache')
|
||||
|
||||
def test_latest_summary_none_when_no_summaries(self):
|
||||
inventory = self._create_inventory_with_hosts(1)
|
||||
host = Host.objects.filter(inventory=inventory).with_latest_summary_id().first()
|
||||
assert host.latest_summary is None
|
||||
|
||||
def test_latest_job_property(self):
|
||||
inventory = self._create_inventory_with_hosts(1)
|
||||
job = self._run_job(inventory)
|
||||
|
||||
host = Host.objects.filter(inventory=inventory).with_latest_summary_id().first()
|
||||
assert host.latest_job is not None
|
||||
assert host.latest_job.id == job.id
|
||||
|
||||
def test_latest_job_none_when_no_summaries(self):
|
||||
inventory = self._create_inventory_with_hosts(1)
|
||||
host = Host.objects.filter(inventory=inventory).first()
|
||||
assert host.latest_job is None
|
||||
|
||||
def test_bulk_attach_select_related(self):
|
||||
"""The bulk-attach should select_related job and job__job_template
|
||||
so accessing them doesn't cause extra queries."""
|
||||
inventory = self._create_inventory_with_hosts(3)
|
||||
self._run_job(inventory)
|
||||
|
||||
hosts = list(Host.objects.filter(inventory=inventory).with_latest_summary_id())
|
||||
|
||||
with CaptureQueriesContext(connection) as ctx:
|
||||
for host in hosts:
|
||||
summary = host.latest_summary
|
||||
_ = summary.job # should not query
|
||||
|
||||
assert len(ctx.captured_queries) == 0
|
||||
|
||||
def test_chaining_preserves_annotation(self):
|
||||
"""Chaining .filter() after .with_latest_summary_id() should
|
||||
preserve the annotation and bulk-attach behavior."""
|
||||
inventory = self._create_inventory_with_hosts(5)
|
||||
self._run_job(inventory)
|
||||
|
||||
hosts = list(Host.objects.filter(inventory=inventory).with_latest_summary_id().filter(name__startswith='host-').order_by('name'))
|
||||
assert len(hosts) == 5
|
||||
for host in hosts:
|
||||
assert hasattr(host, '_latest_summary_cache')
|
||||
assert host._latest_summary_cache is not None
|
||||
|
||||
def test_multiple_jobs_latest_wins(self):
|
||||
"""After multiple jobs, latest_summary should return the most recent."""
|
||||
inventory = self._create_inventory_with_hosts(1)
|
||||
self._run_job(inventory)
|
||||
self._run_job(inventory)
|
||||
job3 = self._run_job(inventory)
|
||||
|
||||
host = Host.objects.filter(inventory=inventory).with_latest_summary_id().first()
|
||||
assert host.latest_summary.job_id == job3.id
|
||||
|
||||
def test_partial_host_coverage(self):
|
||||
"""When a job only touches some hosts, only those hosts get summaries."""
|
||||
inventory = self._create_inventory_with_hosts(5)
|
||||
self._run_job(inventory, host_names=['host-0', 'host-1'])
|
||||
|
||||
hosts = list(Host.objects.filter(inventory=inventory).with_latest_summary_id().order_by('name'))
|
||||
with_summary = [h for h in hosts if h.latest_summary is not None]
|
||||
without_summary = [h for h in hosts if h.latest_summary is None]
|
||||
|
||||
assert len(with_summary) == 2
|
||||
assert len(without_summary) == 3
|
||||
assert sorted([h.name for h in with_summary]) == ['host-0', 'host-1']
|
||||
111
awx/main/tests/functional/models/test_host_summary_fields.py
Normal file
111
awx/main/tests/functional/models/test_host_summary_fields.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import pytest
|
||||
|
||||
from django.utils.timezone import now
|
||||
|
||||
from awx.main.models import Job, JobEvent, JobTemplate, Inventory, Host, JobHostSummary, Project
|
||||
from awx.api.serializers import HostSerializer
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestHostSummaryFields:
|
||||
"""Tests for summary_fields of last_job and last_job_host_summary on HostSerializer."""
|
||||
|
||||
def _setup_host_with_job(self, status='canceled'):
|
||||
inventory = Inventory()
|
||||
inventory.save()
|
||||
host = Host(created=now(), modified=now(), name='test-host', inventory=inventory)
|
||||
host.save()
|
||||
|
||||
project = Project(name='test-project')
|
||||
project.save()
|
||||
jt = JobTemplate(name='test-jt', inventory=inventory, project=project)
|
||||
jt.save()
|
||||
|
||||
job = Job(inventory=inventory, job_template=jt, status=status)
|
||||
if status in ('successful', 'failed', 'canceled', 'error'):
|
||||
job.finished = now()
|
||||
if status == 'canceled':
|
||||
job.canceled_on = now()
|
||||
job.save()
|
||||
|
||||
host_map = {host.name: host.id}
|
||||
JobEvent.create_from_data(
|
||||
job_id=job.pk,
|
||||
parent_uuid='abc123',
|
||||
event='playbook_on_stats',
|
||||
event_data={
|
||||
'ok': {host.name: 1},
|
||||
'changed': {},
|
||||
'dark': {},
|
||||
'failures': {},
|
||||
'ignored': {},
|
||||
'processed': {},
|
||||
'rescued': {},
|
||||
'skipped': {},
|
||||
},
|
||||
host_map=host_map,
|
||||
).save()
|
||||
|
||||
summary = JobHostSummary.objects.filter(host=host, job=job).first()
|
||||
host.last_job = job
|
||||
host.last_job_host_summary = summary
|
||||
host.save(update_fields=['last_job', 'last_job_host_summary'])
|
||||
host.refresh_from_db()
|
||||
|
||||
return host, job, summary
|
||||
|
||||
def test_last_job_summary_fields_canceled_job(self):
|
||||
host, job, summary = self._setup_host_with_job(status='canceled')
|
||||
|
||||
serializer = HostSerializer()
|
||||
d = serializer.get_summary_fields(host)
|
||||
|
||||
assert 'last_job' in d
|
||||
last_job = d['last_job']
|
||||
|
||||
expected_keys = {'id', 'name', 'description', 'finished', 'status', 'failed', 'canceled_on', 'job_template_id', 'job_template_name'}
|
||||
assert set(last_job.keys()) == expected_keys, f"Unexpected last_job keys: {set(last_job.keys())}"
|
||||
assert last_job['id'] == job.id
|
||||
assert last_job['status'] == 'canceled'
|
||||
assert last_job['canceled_on'] == job.canceled_on
|
||||
assert last_job['job_template_id'] == job.job_template.id
|
||||
assert last_job['job_template_name'] == job.job_template.name
|
||||
|
||||
def test_last_job_summary_fields_successful_job(self):
|
||||
host, job, summary = self._setup_host_with_job(status='successful')
|
||||
|
||||
serializer = HostSerializer()
|
||||
d = serializer.get_summary_fields(host)
|
||||
|
||||
assert 'last_job' in d
|
||||
last_job = d['last_job']
|
||||
|
||||
expected_keys = {'id', 'name', 'description', 'finished', 'status', 'failed', 'job_template_id', 'job_template_name'}
|
||||
assert set(last_job.keys()) == expected_keys, f"Unexpected last_job keys: {set(last_job.keys())}"
|
||||
assert last_job['id'] == job.id
|
||||
assert last_job['status'] == 'successful'
|
||||
assert 'canceled_on' not in last_job, "canceled_on should not appear when None"
|
||||
|
||||
def test_last_job_host_summary_fields(self):
|
||||
host, job, summary = self._setup_host_with_job(status='successful')
|
||||
|
||||
serializer = HostSerializer()
|
||||
d = serializer.get_summary_fields(host)
|
||||
|
||||
assert 'last_job_host_summary' in d
|
||||
last_jhs = d['last_job_host_summary']
|
||||
|
||||
assert last_jhs['id'] == summary.id
|
||||
assert 'failed' in last_jhs
|
||||
|
||||
def test_no_summary_fields_without_job(self):
|
||||
inventory = Inventory()
|
||||
inventory.save()
|
||||
host = Host(created=now(), modified=now(), name='lonely-host', inventory=inventory)
|
||||
host.save()
|
||||
|
||||
serializer = HostSerializer()
|
||||
d = serializer.get_summary_fields(host)
|
||||
|
||||
assert 'last_job' not in d
|
||||
assert 'last_job_host_summary' not in d
|
||||
@@ -108,6 +108,28 @@ class TestActiveCount:
|
||||
source.hosts.create(name='remotely-managed-host', inventory=inventory)
|
||||
assert Host.objects.active_count() == 1
|
||||
|
||||
def test_active_count_minus_constructed(self, organization):
|
||||
"""
|
||||
Active hosts do not include duplicated hosts from construted inventories.
|
||||
"""
|
||||
inv = Inventory.objects.create(name='source-inv', organization=organization)
|
||||
inv.hosts.create(name='host1')
|
||||
assert Host.objects.active_count() == 1
|
||||
|
||||
constructed = Inventory.objects.create(name='constructed-inv', kind='constructed', organization=organization)
|
||||
Host.objects.create(name='host1', inventory=constructed)
|
||||
assert Host.objects.active_count() == 1
|
||||
|
||||
def test_org_active_count_minus_constructed(self, organization):
|
||||
"""Org-scoped count must also exclude constructed-inventory shadow rows."""
|
||||
inv = Inventory.objects.create(name='source-inv', organization=organization)
|
||||
inv.hosts.create(name='host1')
|
||||
assert Host.objects.org_active_count(organization.id) == 1
|
||||
|
||||
constructed = Inventory.objects.create(name='constructed-inv', kind='constructed', organization=organization)
|
||||
Host.objects.create(name='host1', inventory=constructed)
|
||||
assert Host.objects.org_active_count(organization.id) == 1
|
||||
|
||||
def test_host_case_insensitivity(self, organization):
|
||||
inv1 = Inventory.objects.create(name='inv1', organization=organization)
|
||||
inv2 = Inventory.objects.create(name='inv2', organization=organization)
|
||||
|
||||
@@ -8,7 +8,7 @@ from crum import impersonate
|
||||
# AWX
|
||||
from awx.main.models import UnifiedJobTemplate, Job, JobTemplate, WorkflowJobTemplate, Project, WorkflowJob, Schedule, Credential
|
||||
from awx.api.versioning import reverse
|
||||
from awx.main.constants import JOB_VARIABLE_PREFIXES
|
||||
from awx.main.utils.common import get_job_variable_prefixes
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@@ -160,7 +160,13 @@ class TestMetaVars:
|
||||
job = Job.objects.create(name='job', created_by=admin_user)
|
||||
job.save()
|
||||
|
||||
user_vars = ['_'.join(x) for x in itertools.product(['tower', 'awx'], ['user_name', 'user_id', 'user_email', 'user_first_name', 'user_last_name'])]
|
||||
user_vars = [
|
||||
'_'.join(x)
|
||||
for x in itertools.product(
|
||||
get_job_variable_prefixes(),
|
||||
['user_name', 'user_id', 'user_email', 'user_first_name', 'user_last_name'],
|
||||
)
|
||||
]
|
||||
|
||||
for key in user_vars:
|
||||
assert key in job.awx_meta_vars()
|
||||
@@ -179,7 +185,7 @@ class TestMetaVars:
|
||||
|
||||
workflow_job.workflow_nodes.create(job=job)
|
||||
data = job.awx_meta_vars()
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in get_job_variable_prefixes():
|
||||
assert data['{}_user_id'.format(name)] == admin_user.id
|
||||
assert data['{}_user_name'.format(name)] == admin_user.username
|
||||
assert data['{}_workflow_job_id'.format(name)] == workflow_job.pk
|
||||
@@ -189,7 +195,7 @@ class TestMetaVars:
|
||||
schedule = Schedule.objects.create(name='job-schedule', rrule='DTSTART:20171129T155939z\nFREQ=MONTHLY', unified_job_template=job_template)
|
||||
job = Job.objects.create(name='fake-job', launch_type='workflow', schedule=schedule, job_template=job_template)
|
||||
data = job.awx_meta_vars()
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in get_job_variable_prefixes():
|
||||
assert data['{}_schedule_id'.format(name)] == schedule.pk
|
||||
assert '{}_user_name'.format(name) not in data
|
||||
|
||||
@@ -201,7 +207,7 @@ class TestMetaVars:
|
||||
job = Job.objects.create(launch_type='workflow')
|
||||
workflow_job.workflow_nodes.create(job=job)
|
||||
result_hash = {}
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in get_job_variable_prefixes():
|
||||
result_hash['{}_job_id'.format(name)] = job.id
|
||||
result_hash['{}_job_launch_type'.format(name)] = 'workflow'
|
||||
result_hash['{}_workflow_job_name'.format(name)] = 'workflow-job'
|
||||
|
||||
@@ -291,33 +291,6 @@ class TestWorkflowJob:
|
||||
assert set(data['labels']) == set(node_labels) # as exception, WFJT labels not applied
|
||||
assert data['limit'] == 'wj_limit'
|
||||
|
||||
def test_node_limit_not_overridden_by_empty_string_wj_limit(self, project, inventory):
|
||||
"""
|
||||
When the workflow job has an empty string limit (e.g., set via IaC with limit: ""),
|
||||
the node-level limit should still be passed to the spawned job, not silently suppressed.
|
||||
"""
|
||||
jt = JobTemplate.objects.create(
|
||||
project=project,
|
||||
inventory=inventory,
|
||||
ask_limit_on_launch=True,
|
||||
)
|
||||
# Simulate a workflow job whose WFJT was created via IaC with `limit: ""`
|
||||
# (e.g. awx.awx.workflow_job_template: ... limit: "")
|
||||
# This stores '' in char_prompts instead of treating it as None/"no limit".
|
||||
wj = WorkflowJob.objects.create(name='test-wf-job')
|
||||
wj.limit = '' # stores {'limit': ''} in char_prompts - the IaC bug scenario
|
||||
wj.save()
|
||||
|
||||
node = WorkflowJobNode.objects.create(workflow_job=wj, unified_job_template=jt)
|
||||
node.limit = 'web_servers'
|
||||
node.save()
|
||||
|
||||
data = node.get_job_kwargs()
|
||||
# The node-level limit should be applied; the WJ's empty string limit is not meaningful
|
||||
assert data.get('limit') == 'web_servers', (
|
||||
"Node-level limit 'web_servers' was not passed to the job. " "Likely caused by an empty string WJ limit overriding the node limit"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestWorkflowJobTemplate:
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
import urllib.parse
|
||||
|
||||
import pytest
|
||||
|
||||
from awx.api.versioning import reverse
|
||||
from awx.main.models import (
|
||||
Group,
|
||||
Host,
|
||||
Inventory,
|
||||
Organization,
|
||||
Schedule,
|
||||
)
|
||||
from awx.main.access import (
|
||||
@@ -128,3 +134,94 @@ class TestSmartInventory:
|
||||
assert InventoryAccess(org_admin).can_admin(smart_inventory, {'host_filter': 'search=foo'})
|
||||
smart_inventory.admin_role.members.add(rando)
|
||||
assert not InventoryAccess(rando).can_admin(smart_inventory, {'host_filter': 'search=foo'})
|
||||
|
||||
def test_host_filter_edit_unprivileged(self, smart_inventory, user):
|
||||
unprivileged = user('unprivileged', False)
|
||||
assert not InventoryAccess(unprivileged).can_change(smart_inventory, None)
|
||||
assert not InventoryAccess(unprivileged).can_admin(smart_inventory, {'host_filter': 'search=bar'})
|
||||
|
||||
def test_host_filter_edit_inventory_admin_role(self, smart_inventory, user):
|
||||
inv_admin = user('inv_admin', False)
|
||||
smart_inventory.admin_role.members.add(inv_admin)
|
||||
assert InventoryAccess(inv_admin).can_change(smart_inventory, None)
|
||||
assert not InventoryAccess(inv_admin).can_admin(smart_inventory, {'host_filter': 'search=bar'})
|
||||
|
||||
def test_host_filter_edit_org_admin_via_api(self, smart_inventory, patch, user):
|
||||
oa = user('smart_oa', False)
|
||||
smart_inventory.organization.admin_role.members.add(oa)
|
||||
url = reverse('api:inventory_detail', kwargs={'pk': smart_inventory.pk})
|
||||
resp = patch(url, {'host_filter': 'search=bar'}, oa, expect=200)
|
||||
assert resp.data['host_filter'] == 'search=bar'
|
||||
|
||||
@pytest.mark.parametrize("role_field", ['admin_role', 'use_role', 'adhoc_role', 'read_role'])
|
||||
def test_inventory_role_cannot_edit_host_filter(self, smart_inventory, patch, user, role_field):
|
||||
u = user('role_test_user', False)
|
||||
getattr(smart_inventory, role_field).members.add(u)
|
||||
url = reverse('api:inventory_detail', kwargs={'pk': smart_inventory.pk})
|
||||
patch(url, {'host_filter': 'search=bar'}, u, expect=403)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestHostFilterRBAC:
|
||||
@pytest.fixture
|
||||
def two_org_inventories(self):
|
||||
orgA = Organization.objects.create(name="rbac-orgA")
|
||||
orgB = Organization.objects.create(name="rbac-orgB")
|
||||
invA = Inventory.objects.create(name="rbac-invA", organization=orgA)
|
||||
invB = Inventory.objects.create(name="rbac-invB", organization=orgB)
|
||||
hostA = Host.objects.create(name="shared_name", inventory=invA)
|
||||
hostB = Host.objects.create(name="shared_name", inventory=invB)
|
||||
groupA = Group.objects.create(name="shared_group", inventory=invA)
|
||||
groupB = Group.objects.create(name="shared_group", inventory=invB)
|
||||
groupA.hosts.add(hostA)
|
||||
groupB.hosts.add(hostB)
|
||||
return {
|
||||
'orgA': orgA,
|
||||
'orgB': orgB,
|
||||
'invA': invA,
|
||||
'invB': invB,
|
||||
'hostA': hostA,
|
||||
'hostB': hostB,
|
||||
}
|
||||
|
||||
@pytest.mark.parametrize("host_filter", ["name=shared_name", "groups__name=shared_group"])
|
||||
def test_host_filter_scoped_to_inventory_read_role(self, two_org_inventories, get, user, host_filter):
|
||||
data = two_org_inventories
|
||||
userA = user('rbac_userA', False)
|
||||
userB = user('rbac_userB', False)
|
||||
data['invA'].read_role.members.add(userA)
|
||||
data['invB'].read_role.members.add(userB)
|
||||
|
||||
url = reverse('api:host_list')
|
||||
params = "?host_filter=%s" % urllib.parse.quote(host_filter, safe='')
|
||||
|
||||
respA = get(url + params, userA)
|
||||
idsA = [h['id'] for h in respA.data['results']]
|
||||
assert data['hostA'].id in idsA
|
||||
assert data['hostB'].id not in idsA
|
||||
|
||||
respB = get(url + params, userB)
|
||||
idsB = [h['id'] for h in respB.data['results']]
|
||||
assert data['hostB'].id in idsB
|
||||
assert data['hostA'].id not in idsB
|
||||
|
||||
@pytest.mark.parametrize("host_filter", ["name=shared_name", "groups__name=shared_group"])
|
||||
def test_host_filter_scoped_to_org_admin(self, two_org_inventories, get, user, host_filter):
|
||||
data = two_org_inventories
|
||||
adminA = user('rbac_adminA', False)
|
||||
adminB = user('rbac_adminB', False)
|
||||
data['orgA'].admin_role.members.add(adminA)
|
||||
data['orgB'].admin_role.members.add(adminB)
|
||||
|
||||
url = reverse('api:host_list')
|
||||
params = "?host_filter=%s" % urllib.parse.quote(host_filter, safe='')
|
||||
|
||||
respA = get(url + params, adminA)
|
||||
idsA = [h['id'] for h in respA.data['results']]
|
||||
assert data['hostA'].id in idsA
|
||||
assert data['hostB'].id not in idsA
|
||||
|
||||
respB = get(url + params, adminB)
|
||||
idsB = [h['id'] for h in respB.data['results']]
|
||||
assert data['hostB'].id in idsB
|
||||
assert data['hostA'].id not in idsB
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from django.apps import apps
|
||||
from django.core.management.base import CommandError
|
||||
|
||||
from awx.main.tasks.system import _sync_credential_types_to_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -9,18 +12,38 @@ def mock_setup_tower_managed_defaults(mocker):
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_load_credential_types_feature_migrations_ran(mocker, mock_setup_tower_managed_defaults):
|
||||
mocker.patch('awx.main.apps.is_database_synchronized', return_value=True)
|
||||
def test_sync_credential_types_migrations_ran(mocker, mock_setup_tower_managed_defaults):
|
||||
mocker.patch('awx.main.tasks.system.is_database_synchronized', return_value=True)
|
||||
|
||||
apps.get_app_config('main')._load_credential_types_feature()
|
||||
_sync_credential_types_to_db()
|
||||
|
||||
mock_setup_tower_managed_defaults.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_load_credential_types_feature_migrations_not_ran(mocker, mock_setup_tower_managed_defaults):
|
||||
mocker.patch('awx.main.apps.is_database_synchronized', return_value=False)
|
||||
def test_sync_credential_types_migrations_not_ran(mocker, mock_setup_tower_managed_defaults):
|
||||
mocker.patch('awx.main.tasks.system.is_database_synchronized', return_value=False)
|
||||
|
||||
apps.get_app_config('main')._load_credential_types_feature()
|
||||
_sync_credential_types_to_db()
|
||||
|
||||
mock_setup_tower_managed_defaults.assert_not_called()
|
||||
|
||||
|
||||
def test_check_db_requirement_no_violations(mocker):
|
||||
mocker.patch('awx.main.apps.db_requirement_violations', return_value=None)
|
||||
main_config = apps.get_app_config('main')
|
||||
|
||||
result = main_config.check_db_requirement()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_check_db_requirement_with_violations(mocker):
|
||||
violation_msg = "Database version check failed"
|
||||
mocker.patch('awx.main.apps.db_requirement_violations', return_value=violation_msg)
|
||||
main_config = apps.get_app_config('main')
|
||||
|
||||
with pytest.raises(CommandError) as exc_info:
|
||||
main_config.check_db_requirement()
|
||||
|
||||
assert str(exc_info.value) == violation_msg
|
||||
|
||||
@@ -8,6 +8,8 @@ from awx.main.models.jobs import JobTemplate
|
||||
from awx.main.models import Organization, Inventory, WorkflowJob, ExecutionEnvironment, Host
|
||||
from awx.main.scheduler import TaskManager
|
||||
|
||||
from django.test import override_settings
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.parametrize('num_hosts, num_queries', [(1, 15), (10, 15)])
|
||||
@@ -445,3 +447,185 @@ def get_inventory_hosts(get, inv_id, use_user):
|
||||
data = get(reverse('api:inventory_hosts_list', kwargs={'pk': inv_id}), use_user, expect=200).data
|
||||
results = [host['id'] for host in data['results']]
|
||||
return results
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_bulk_job_launch_respects_settings_limit(job_template, organization, inventory, project, post, patch, get, user):
|
||||
"""Test that bulk job launch respects BULK_JOB_MAX_LAUNCH setting."""
|
||||
normal_user = user('normal_user', False)
|
||||
organization.member_role.members.add(normal_user)
|
||||
|
||||
jt = JobTemplate.objects.create(
|
||||
name='bulk-test-jt',
|
||||
ask_inventory_on_launch=True,
|
||||
project=project,
|
||||
playbook='helloworld.yml',
|
||||
allow_simultaneous=True,
|
||||
)
|
||||
jt.execute_role.members.add(normal_user)
|
||||
inventory.use_role.members.add(normal_user)
|
||||
|
||||
# Test with limit set to 3
|
||||
with override_settings(BULK_JOB_MAX_LAUNCH=3):
|
||||
# Attempt to launch 5 jobs when limit is 3 - should fail
|
||||
jobs = [{'unified_job_template': jt.id, 'inventory': inventory.id} for _ in range(5)]
|
||||
resp = post(
|
||||
reverse('api:bulk_job_launch'),
|
||||
{'name': 'Bulk Job Test', 'jobs': jobs},
|
||||
normal_user,
|
||||
expect=400,
|
||||
)
|
||||
assert 'Number of requested jobs exceeds system setting' in str(resp.data)
|
||||
|
||||
# Test with limit increased to 10
|
||||
with override_settings(BULK_JOB_MAX_LAUNCH=10):
|
||||
# Now launching 5 jobs should succeed
|
||||
jobs = [{'unified_job_template': jt.id, 'inventory': inventory.id} for _ in range(5)]
|
||||
resp = post(
|
||||
reverse('api:bulk_job_launch'),
|
||||
{'name': 'Bulk Job Test', 'jobs': jobs},
|
||||
normal_user,
|
||||
expect=201,
|
||||
)
|
||||
bulk_job = get(resp.data['url'], normal_user, expect=200).data
|
||||
# Verify the workflow job was created
|
||||
assert bulk_job['name'] == 'Bulk Job Test'
|
||||
|
||||
|
||||
# Tests for BulkHostCreateSerializer duplicate detection optimization
|
||||
@pytest.mark.django_db
|
||||
def test_bulk_host_create_duplicate_within_batch(organization, inventory, post, user):
|
||||
"""
|
||||
Test that duplicate hostnames within the same batch are detected.
|
||||
This tests the Counter-based duplicate detection logic.
|
||||
"""
|
||||
inventory.organization = organization
|
||||
inv_admin = user('inventory_admin', False)
|
||||
organization.member_role.members.add(inv_admin)
|
||||
inventory.admin_role.members.add(inv_admin)
|
||||
|
||||
# Try to create hosts where 'duplicate-host' appears twice in the same batch
|
||||
hosts = [
|
||||
{'name': 'unique-host-1'},
|
||||
{'name': 'duplicate-host'},
|
||||
{'name': 'unique-host-2'},
|
||||
{'name': 'duplicate-host'}, # Duplicate within batch
|
||||
]
|
||||
|
||||
response = post(reverse('api:bulk_host_create'), {'inventory': inventory.id, 'hosts': hosts}, inv_admin, expect=400)
|
||||
|
||||
assert 'Hostnames must be unique in an inventory' in response.data['__all__'][0]
|
||||
assert 'duplicate-host' in response.data['__all__'][0]
|
||||
assert Host.objects.filter(inventory=inventory).count() == 0
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_bulk_host_create_duplicate_against_existing(organization, inventory, post, user):
|
||||
"""
|
||||
Test that duplicate hostnames against existing inventory hosts are detected.
|
||||
This tests the database query-based duplicate detection.
|
||||
"""
|
||||
inventory.organization = organization
|
||||
inv_admin = user('inventory_admin', False)
|
||||
organization.member_role.members.add(inv_admin)
|
||||
inventory.admin_role.members.add(inv_admin)
|
||||
|
||||
Host.objects.create(name='existing-host-1', inventory=inventory)
|
||||
Host.objects.create(name='existing-host-2', inventory=inventory)
|
||||
|
||||
# Try to create hosts where one already exists
|
||||
hosts = [
|
||||
{'name': 'new-host-1'},
|
||||
{'name': 'existing-host-1'},
|
||||
{'name': 'new-host-2'},
|
||||
]
|
||||
|
||||
response = post(reverse('api:bulk_host_create'), {'inventory': inventory.id, 'hosts': hosts}, inv_admin, expect=400)
|
||||
|
||||
assert 'Hostnames must be unique in an inventory' in response.data['__all__'][0]
|
||||
assert 'existing-host-1' in response.data['__all__'][0]
|
||||
assert Host.objects.filter(inventory=inventory).count() == 2
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_bulk_host_create_combined_duplicates(organization, inventory, post, user):
|
||||
"""
|
||||
Test detection of both batch-internal duplicates and duplicates against existing hosts.
|
||||
"""
|
||||
inventory.organization = organization
|
||||
inventory_admin = user('inventory_admin', False)
|
||||
organization.member_role.members.add(inventory_admin)
|
||||
inventory.admin_role.members.add(inventory_admin)
|
||||
|
||||
Host.objects.create(name='existing-host', inventory=inventory)
|
||||
|
||||
# Try to create hosts with both types of duplicates
|
||||
hosts = [
|
||||
{'name': 'new-host'},
|
||||
{'name': 'batch-duplicate'},
|
||||
{'name': 'existing-host'},
|
||||
{'name': 'batch-duplicate'},
|
||||
]
|
||||
|
||||
response = post(reverse('api:bulk_host_create'), {'inventory': inventory.id, 'hosts': hosts}, inventory_admin, expect=400)
|
||||
|
||||
error_message = response.data['__all__'][0]
|
||||
assert 'Hostnames must be unique in an inventory' in error_message
|
||||
assert 'batch-duplicate' in error_message or 'existing-host' in error_message
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_bulk_host_create_no_duplicates_success(organization, inventory, post, user):
|
||||
"""
|
||||
Test that hosts are created successfully when there are no duplicates.
|
||||
"""
|
||||
inventory.organization = organization
|
||||
inventory_admin = user('inventory_admin', False)
|
||||
organization.member_role.members.add(inventory_admin)
|
||||
inventory.admin_role.members.add(inventory_admin)
|
||||
|
||||
Host.objects.create(name='existing-host-1', inventory=inventory)
|
||||
Host.objects.create(name='existing-host-2', inventory=inventory)
|
||||
|
||||
# Create new hosts with unique names
|
||||
hosts = [
|
||||
{'name': 'new-host-1'},
|
||||
{'name': 'new-host-2'},
|
||||
{'name': 'new-host-3'},
|
||||
]
|
||||
|
||||
response = post(reverse('api:bulk_host_create'), {'inventory': inventory.id, 'hosts': hosts}, inventory_admin, expect=201)
|
||||
|
||||
assert len(response.data['hosts']) == 3
|
||||
assert Host.objects.filter(inventory=inventory).count() == 5
|
||||
assert Host.objects.filter(inventory=inventory, name='new-host-1').exists()
|
||||
assert Host.objects.filter(inventory=inventory, name='new-host-2').exists()
|
||||
assert Host.objects.filter(inventory=inventory, name='new-host-3').exists()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_bulk_host_create_performance_large_inventory(organization, inventory, post, user, django_assert_max_num_queries):
|
||||
"""
|
||||
Test that duplicate detection is performant and doesn't load all hosts.
|
||||
"""
|
||||
inventory.organization = organization
|
||||
inventory_admin = user('inventory_admin', False)
|
||||
organization.member_role.members.add(inventory_admin)
|
||||
inventory.admin_role.members.add(inventory_admin)
|
||||
|
||||
# Create 10k existing hosts to simulate a reasonably large inventory
|
||||
from django.utils.timezone import now
|
||||
|
||||
_now = now()
|
||||
existing_hosts = [Host(name=f'existing-host-{i}', inventory=inventory, created=_now, modified=_now) for i in range(10000)]
|
||||
Host.objects.bulk_create(existing_hosts)
|
||||
|
||||
new_hosts = [{'name': f'new-host-{i}'} for i in range(10)]
|
||||
|
||||
# The number of queries should be bounded and not scale with inventory size
|
||||
# This should be around 15-20 queries regardless of whether there are 10k or 500k+ existing hosts
|
||||
with django_assert_max_num_queries(20):
|
||||
response = post(reverse('api:bulk_host_create'), {'inventory': inventory.id, 'hosts': new_hosts}, inventory_admin, expect=201)
|
||||
|
||||
assert len(response.data['hosts']) == 10
|
||||
assert Host.objects.filter(inventory=inventory).count() == 10010
|
||||
|
||||
@@ -160,3 +160,38 @@ class TestJobReaper(object):
|
||||
assert job.started > ref_time
|
||||
assert job.status == 'running'
|
||||
assert job.job_explanation == ''
|
||||
|
||||
def test_waiting_job_reset_when_controller_node_deprovisioned(self):
|
||||
"""When a controller pod is replaced (e.g. K8s rollout), waiting jobs
|
||||
assigned to the now-gone controller_node should be reset to pending
|
||||
by the task manager so they can be re-dispatched."""
|
||||
from awx.main.scheduler import TaskManager
|
||||
|
||||
live_inst = Instance(hostname='awx-task-live', node_type='control')
|
||||
live_inst.save()
|
||||
# No instance record for 'awx-task-dead' — it was already deprovisioned
|
||||
job = Job.objects.create(status='waiting', controller_node='awx-task-dead', execution_node='')
|
||||
|
||||
tm = TaskManager()
|
||||
tm.reap_jobs_from_orphaned_instances()
|
||||
|
||||
job.refresh_from_db()
|
||||
assert job.status == 'pending'
|
||||
assert job.controller_node == ''
|
||||
assert job.execution_node == ''
|
||||
|
||||
@pytest.mark.parametrize('node_type', ['control', 'hybrid'])
|
||||
def test_waiting_job_not_reset_when_controller_node_alive(self, node_type):
|
||||
"""Waiting jobs on a live control or hybrid node should not be touched."""
|
||||
from awx.main.scheduler import TaskManager
|
||||
|
||||
live_inst = Instance(hostname='awx-task-live', node_type=node_type)
|
||||
live_inst.save()
|
||||
job = Job.objects.create(status='waiting', controller_node='awx-task-live', execution_node='')
|
||||
|
||||
tm = TaskManager()
|
||||
tm.reap_jobs_from_orphaned_instances()
|
||||
|
||||
job.refresh_from_db()
|
||||
assert job.status == 'waiting'
|
||||
assert job.controller_node == 'awx-task-live'
|
||||
|
||||
@@ -287,6 +287,20 @@ def test_control_plane_policy_exception(controlplane_instance_group):
|
||||
assert 'foo-1' not in [inst.hostname for inst in controlplane_instance_group.instances.all()]
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_policy_instance_list_controlplane_excludes_execution_node(controlplane_instance_group):
|
||||
controlplane_instance_group.policy_instance_percentage = 100
|
||||
controlplane_instance_group.save()
|
||||
exec_inst = Instance.objects.create(hostname='exec-1', node_type='execution')
|
||||
control_inst = Instance.objects.create(hostname='control-1', node_type='control')
|
||||
controlplane_instance_group.policy_instance_list = [exec_inst.hostname]
|
||||
controlplane_instance_group.save()
|
||||
apply_cluster_membership_policies()
|
||||
members = list(controlplane_instance_group.instances.all())
|
||||
assert exec_inst not in members
|
||||
assert control_inst in members
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_normal_instance_group_policy_exception():
|
||||
ig = InstanceGroup.objects.create(name='bar', policy_instance_percentage=100, policy_instance_minimum=2)
|
||||
|
||||
@@ -486,7 +486,7 @@ def test_populate_workload_identity_tokens_with_flag_enabled(job_template_with_c
|
||||
managed=False,
|
||||
inputs={
|
||||
'fields': [
|
||||
{'id': 'jwt_aud', 'type': 'string', 'label': 'JWT Audience'},
|
||||
{'id': 'url', 'type': 'string', 'label': 'Server URL'},
|
||||
{'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
|
||||
]
|
||||
},
|
||||
@@ -495,7 +495,7 @@ def test_populate_workload_identity_tokens_with_flag_enabled(job_template_with_c
|
||||
|
||||
# Create credentials
|
||||
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
|
||||
source_cred = Credential.objects.create(credential_type=hashivault_type, name='vault-source', inputs={'jwt_aud': 'https://vault.example.com'})
|
||||
source_cred = Credential.objects.create(credential_type=hashivault_type, name='vault-source', inputs={'url': 'https://vault.example.com'})
|
||||
target_cred = Credential.objects.create(credential_type=ssh_type, name='target-cred', inputs={'username': 'testuser'})
|
||||
|
||||
# Create input source linking source credential to target credential
|
||||
@@ -545,7 +545,7 @@ def test_populate_workload_identity_tokens_passes_workload_ttl_from_job_timeout(
|
||||
managed=False,
|
||||
inputs={
|
||||
'fields': [
|
||||
{'id': 'jwt_aud', 'type': 'string', 'label': 'JWT Audience'},
|
||||
{'id': 'url', 'type': 'string', 'label': 'Server URL'},
|
||||
{'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
|
||||
]
|
||||
},
|
||||
@@ -553,7 +553,7 @@ def test_populate_workload_identity_tokens_passes_workload_ttl_from_job_timeout(
|
||||
hashivault_type.save()
|
||||
|
||||
ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred')
|
||||
source_cred = Credential.objects.create(credential_type=hashivault_type, name='vault-source', inputs={'jwt_aud': 'https://vault.example.com'})
|
||||
source_cred = Credential.objects.create(credential_type=hashivault_type, name='vault-source', inputs={'url': 'https://vault.example.com'})
|
||||
target_cred = Credential.objects.create(credential_type=ssh_type, name='target-cred', inputs={'username': 'testuser'})
|
||||
|
||||
CredentialInputSource.objects.create(
|
||||
@@ -595,7 +595,7 @@ def test_populate_workload_identity_tokens_with_flag_disabled(job_template_with_
|
||||
managed=False,
|
||||
inputs={
|
||||
'fields': [
|
||||
{'id': 'jwt_aud', 'type': 'string', 'label': 'JWT Audience'},
|
||||
{'id': 'url', 'type': 'string', 'label': 'Server URL'},
|
||||
{'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
|
||||
]
|
||||
},
|
||||
@@ -647,7 +647,7 @@ def test_populate_workload_identity_tokens_multiple_input_sources_per_credential
|
||||
managed=False,
|
||||
inputs={
|
||||
'fields': [
|
||||
{'id': 'jwt_aud', 'type': 'string', 'label': 'JWT Audience'},
|
||||
{'id': 'url', 'type': 'string', 'label': 'Server URL'},
|
||||
{'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
|
||||
]
|
||||
},
|
||||
@@ -660,7 +660,7 @@ def test_populate_workload_identity_tokens_multiple_input_sources_per_credential
|
||||
managed=False,
|
||||
inputs={
|
||||
'fields': [
|
||||
{'id': 'jwt_aud', 'type': 'string', 'label': 'JWT Audience'},
|
||||
{'id': 'url', 'type': 'string', 'label': 'Server URL'},
|
||||
{'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True},
|
||||
]
|
||||
},
|
||||
@@ -668,11 +668,9 @@ def test_populate_workload_identity_tokens_multiple_input_sources_per_credential
|
||||
hashivault_ssh_type.save()
|
||||
|
||||
# Create source credentials with different audiences
|
||||
source_cred_kv = Credential.objects.create(
|
||||
credential_type=hashivault_kv_type, name='vault-kv-source', inputs={'jwt_aud': 'https://vault-kv.example.com'}
|
||||
)
|
||||
source_cred_kv = Credential.objects.create(credential_type=hashivault_kv_type, name='vault-kv-source', inputs={'url': 'https://vault-kv.example.com'})
|
||||
source_cred_ssh = Credential.objects.create(
|
||||
credential_type=hashivault_ssh_type, name='vault-ssh-source', inputs={'jwt_aud': 'https://vault-ssh.example.com'}
|
||||
credential_type=hashivault_ssh_type, name='vault-ssh-source', inputs={'url': 'https://vault-ssh.example.com'}
|
||||
)
|
||||
|
||||
# Create target credential that uses both sources for different fields
|
||||
|
||||
206
awx/main/tests/live/tests/test_nested_workflow_artifacts.py
Normal file
206
awx/main/tests/live/tests/test_nested_workflow_artifacts.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import json
|
||||
import pytest
|
||||
|
||||
from awx.main.tests.live.tests.conftest import wait_for_job
|
||||
|
||||
from awx.main.models import JobTemplate, WorkflowJobTemplate, WorkflowJobTemplateNode
|
||||
|
||||
JT_NAMES = ('artifact-test-first', 'artifact-test-second', 'artifact-test-reader')
|
||||
WFT_NAMES = ('artifact-test-outer-wf', 'artifact-test-inner-wf')
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_nested_workflow_set_stats_precedence(live_tmp_folder, demo_inv, project_factory, default_org):
|
||||
"""Reproducer for set_stats artifacts from an outer workflow leaking into
|
||||
an inner (child) workflow and overriding the inner workflow's own artifacts.
|
||||
|
||||
Outer WF: [job_first] --success--> [inner_wf]
|
||||
Inner WF: [job_second] --success--> [job_reader]
|
||||
|
||||
job_first sets via set_stats:
|
||||
var1: "outer-only" (only source, should propagate through)
|
||||
var2: "should-be-overridden" (will be overridden by job_second)
|
||||
|
||||
job_second sets via set_stats:
|
||||
var2: "from-inner" (should override outer's value)
|
||||
var3: "inner-only" (only source, should be available)
|
||||
|
||||
job_reader runs debug.yml (no set_stats), we inspect its extra_vars:
|
||||
var1 should be "outer-only" - outer artifacts propagate when uncontested
|
||||
var2 should be "from-inner" - inner artifacts override outer (THE BUG)
|
||||
var3 should be "inner-only" - inner-only artifacts propagate normally
|
||||
"""
|
||||
# Clean up resources from prior runs (delete individually for signals)
|
||||
for name in WFT_NAMES:
|
||||
for wft in WorkflowJobTemplate.objects.filter(name=name):
|
||||
wft.delete()
|
||||
for name in JT_NAMES:
|
||||
for jt in JobTemplate.objects.filter(name=name):
|
||||
jt.delete()
|
||||
|
||||
proj = project_factory(scm_url=f'file://{live_tmp_folder}/debug')
|
||||
if proj.current_job:
|
||||
wait_for_job(proj.current_job)
|
||||
|
||||
# job_first: sets var1 (outer-only) and var2 (to be overridden by inner)
|
||||
jt_first = JobTemplate.objects.create(
|
||||
name='artifact-test-first',
|
||||
project=proj,
|
||||
playbook='set_stats.yml',
|
||||
inventory=demo_inv,
|
||||
extra_vars=json.dumps({'stats_data': {'var1': 'outer-only', 'var2': 'should-be-overridden'}}),
|
||||
)
|
||||
# job_second: overrides var2, introduces var3
|
||||
jt_second = JobTemplate.objects.create(
|
||||
name='artifact-test-second',
|
||||
project=proj,
|
||||
playbook='set_stats.yml',
|
||||
inventory=demo_inv,
|
||||
extra_vars=json.dumps({'stats_data': {'var2': 'from-inner', 'var3': 'inner-only'}}),
|
||||
)
|
||||
# job_reader: just runs, we check what extra_vars it receives
|
||||
jt_reader = JobTemplate.objects.create(
|
||||
name='artifact-test-reader',
|
||||
project=proj,
|
||||
playbook='debug.yml',
|
||||
inventory=demo_inv,
|
||||
)
|
||||
|
||||
# Inner WFT: job_second -> job_reader
|
||||
inner_wft = WorkflowJobTemplate.objects.create(name='artifact-test-inner-wf', organization=default_org)
|
||||
inner_node_1 = WorkflowJobTemplateNode.objects.create(
|
||||
workflow_job_template=inner_wft,
|
||||
unified_job_template=jt_second,
|
||||
identifier='second',
|
||||
)
|
||||
inner_node_2 = WorkflowJobTemplateNode.objects.create(
|
||||
workflow_job_template=inner_wft,
|
||||
unified_job_template=jt_reader,
|
||||
identifier='reader',
|
||||
)
|
||||
inner_node_1.success_nodes.add(inner_node_2)
|
||||
|
||||
# Outer WFT: job_first -> inner_wf
|
||||
outer_wft = WorkflowJobTemplate.objects.create(name='artifact-test-outer-wf', organization=default_org)
|
||||
outer_node_1 = WorkflowJobTemplateNode.objects.create(
|
||||
workflow_job_template=outer_wft,
|
||||
unified_job_template=jt_first,
|
||||
identifier='first',
|
||||
)
|
||||
outer_node_2 = WorkflowJobTemplateNode.objects.create(
|
||||
workflow_job_template=outer_wft,
|
||||
unified_job_template=inner_wft,
|
||||
identifier='inner',
|
||||
)
|
||||
outer_node_1.success_nodes.add(outer_node_2)
|
||||
|
||||
# Launch and wait
|
||||
outer_wfj = outer_wft.create_unified_job()
|
||||
outer_wfj.signal_start()
|
||||
wait_for_job(outer_wfj, running_timeout=120)
|
||||
|
||||
# Find the reader job inside the inner workflow
|
||||
inner_wf_node = outer_wfj.workflow_job_nodes.get(identifier='inner')
|
||||
inner_wfj = inner_wf_node.job
|
||||
assert inner_wfj is not None, 'Inner workflow job was never created'
|
||||
|
||||
# Check that root node of inner WF (job_second) received outer artifacts
|
||||
second_node = inner_wfj.workflow_job_nodes.get(identifier='second')
|
||||
assert second_node.job is not None, 'Second job was never created'
|
||||
second_extra_vars = json.loads(second_node.job.extra_vars)
|
||||
assert second_extra_vars.get('var1') == 'outer-only', (
|
||||
f'Root node var1: expected "outer-only" (outer artifact should be available to root node), '
|
||||
f'got "{second_extra_vars.get("var1")}". '
|
||||
f'Outer artifacts are not reaching root nodes of child workflows.'
|
||||
)
|
||||
|
||||
reader_node = inner_wfj.workflow_job_nodes.get(identifier='reader')
|
||||
assert reader_node.job is not None, 'Reader job was never created'
|
||||
|
||||
reader_extra_vars = json.loads(reader_node.job.extra_vars)
|
||||
|
||||
# var1: only set by outer job_first, no conflict — should propagate through
|
||||
assert reader_extra_vars.get('var1') == 'outer-only', f'var1: expected "outer-only" (uncontested outer artifact), ' f'got "{reader_extra_vars.get("var1")}"'
|
||||
|
||||
# var2: set by outer as "should-be-overridden", then by inner as "from-inner"
|
||||
# Inner workflow's own ancestor artifacts should take precedence
|
||||
assert reader_extra_vars.get('var2') == 'from-inner', (
|
||||
f'var2: expected "from-inner" (inner workflow artifact should override outer), '
|
||||
f'got "{reader_extra_vars.get("var2")}". '
|
||||
f'Outer workflow artifacts are leaking via wj_special_vars. '
|
||||
f'reader node ancestor_artifacts={reader_node.ancestor_artifacts}'
|
||||
)
|
||||
|
||||
# var3: only set by inner job_second — should propagate normally
|
||||
assert reader_extra_vars.get('var3') == 'inner-only', f'var3: expected "inner-only" (inner-only artifact), ' f'got "{reader_extra_vars.get("var3")}"'
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_workflow_extra_vars_override_artifacts(live_tmp_folder, demo_inv, project_factory, default_org):
|
||||
"""Workflow extra_vars should take precedence over set_stats artifacts
|
||||
within a single (non-nested) workflow.
|
||||
|
||||
WF (extra_vars: my_var="from-wf-extra-vars"):
|
||||
[job_setter] --success--> [job_reader]
|
||||
|
||||
job_setter sets my_var="from-set-stats" via set_stats
|
||||
job_reader should see my_var="from-wf-extra-vars" because workflow
|
||||
extra_vars are higher precedence than ancestor artifacts.
|
||||
"""
|
||||
wft_name = 'artifact-test-wf-extra-vars-precedence'
|
||||
jt_names = ('artifact-test-setter', 'artifact-test-checker')
|
||||
|
||||
for wft in WorkflowJobTemplate.objects.filter(name=wft_name):
|
||||
wft.delete()
|
||||
for name in jt_names:
|
||||
for jt in JobTemplate.objects.filter(name=name):
|
||||
jt.delete()
|
||||
|
||||
proj = project_factory(scm_url=f'file://{live_tmp_folder}/debug')
|
||||
if proj.current_job:
|
||||
wait_for_job(proj.current_job)
|
||||
|
||||
jt_setter = JobTemplate.objects.create(
|
||||
name='artifact-test-setter',
|
||||
project=proj,
|
||||
playbook='set_stats.yml',
|
||||
inventory=demo_inv,
|
||||
extra_vars=json.dumps({'stats_data': {'my_var': 'from-set-stats'}}),
|
||||
)
|
||||
jt_checker = JobTemplate.objects.create(
|
||||
name='artifact-test-checker',
|
||||
project=proj,
|
||||
playbook='debug.yml',
|
||||
inventory=demo_inv,
|
||||
)
|
||||
|
||||
wft = WorkflowJobTemplate.objects.create(
|
||||
name=wft_name,
|
||||
organization=default_org,
|
||||
extra_vars=json.dumps({'my_var': 'from-wf-extra-vars'}),
|
||||
)
|
||||
node_1 = WorkflowJobTemplateNode.objects.create(
|
||||
workflow_job_template=wft,
|
||||
unified_job_template=jt_setter,
|
||||
identifier='setter',
|
||||
)
|
||||
node_2 = WorkflowJobTemplateNode.objects.create(
|
||||
workflow_job_template=wft,
|
||||
unified_job_template=jt_checker,
|
||||
identifier='checker',
|
||||
)
|
||||
node_1.success_nodes.add(node_2)
|
||||
|
||||
wfj = wft.create_unified_job()
|
||||
wfj.signal_start()
|
||||
wait_for_job(wfj, running_timeout=120)
|
||||
|
||||
checker_node = wfj.workflow_job_nodes.get(identifier='checker')
|
||||
assert checker_node.job is not None, 'Checker job was never created'
|
||||
|
||||
checker_extra_vars = json.loads(checker_node.job.extra_vars)
|
||||
assert checker_extra_vars.get('my_var') == 'from-wf-extra-vars', (
|
||||
f'Expected my_var="from-wf-extra-vars" (workflow extra_vars should override artifacts), '
|
||||
f'got my_var="{checker_extra_vars.get("my_var")}". '
|
||||
f'checker node ancestor_artifacts={checker_node.ancestor_artifacts}'
|
||||
)
|
||||
320
awx/main/tests/live/tests/test_smart_inventory.py
Normal file
320
awx/main/tests/live/tests/test_smart_inventory.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""Smart inventory tests that require PostgreSQL.
|
||||
|
||||
These tests exercise SmartFilter and smart inventory host resolution against
|
||||
a real PostgreSQL database. Most are unit-style tests that set ansible_facts
|
||||
directly on Host objects rather than running playbooks.
|
||||
|
||||
The smart inventory HostManager uses DISTINCT ON which requires PostgreSQL,
|
||||
so any test that reads smart inventory hosts must run here (not in functional/).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from awx.main.models import Organization, Inventory, Host, Group
|
||||
from awx.main.utils.filters import SmartFilter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fact_org():
|
||||
org, _ = Organization.objects.get_or_create(name='smart-inv-fact-test-org')
|
||||
return org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fact_inventory(fact_org):
|
||||
inv, created = Inventory.objects.get_or_create(name='smart-inv-fact-test-inv', organization=fact_org)
|
||||
if not created:
|
||||
inv.hosts.all().delete()
|
||||
inv.groups.all().delete()
|
||||
|
||||
groupA = Group.objects.create(name='factGroupA', inventory=inv)
|
||||
groupB = Group.objects.create(name='factGroupB', inventory=inv)
|
||||
|
||||
hostA = Host.objects.create(
|
||||
name='factHostA',
|
||||
inventory=inv,
|
||||
ansible_facts={
|
||||
'ansible_system': 'Linux',
|
||||
'ansible_distribution': 'CentOS',
|
||||
'ansible_python': {
|
||||
'version': {'major': 3, 'minor': 9, 'micro': 7},
|
||||
'version_info': [3, 9, 7, 'final', 0],
|
||||
},
|
||||
'ansible_env': {'HOME': '/root'},
|
||||
},
|
||||
)
|
||||
hostB = Host.objects.create(
|
||||
name='factHostB',
|
||||
inventory=inv,
|
||||
ansible_facts={
|
||||
'ansible_system': 'Linux',
|
||||
'ansible_distribution': 'Ubuntu',
|
||||
'ansible_python': {
|
||||
'version': {'major': 3, 'minor': 11, 'micro': 2},
|
||||
'version_info': [3, 11, 2, 'final', 0],
|
||||
},
|
||||
'ansible_env': {'HOME': '/home/user'},
|
||||
},
|
||||
)
|
||||
hostC = Host.objects.create(
|
||||
name='factHostC',
|
||||
inventory=inv,
|
||||
ansible_facts={
|
||||
'ansible_system': 'Darwin',
|
||||
'ansible_distribution': 'MacOSX',
|
||||
'ansible_python': {
|
||||
'version': {'major': 3, 'minor': 10, 'micro': 0},
|
||||
'version_info': [3, 10, 0, 'final', 0],
|
||||
},
|
||||
'ansible_env': {'HOME': '/Users/test'},
|
||||
},
|
||||
)
|
||||
|
||||
groupA.hosts.add(hostA, hostC)
|
||||
groupB.hosts.add(hostB, hostC)
|
||||
|
||||
yield {
|
||||
'org': fact_org,
|
||||
'inv': inv,
|
||||
'hosts': {'hostA': hostA, 'hostB': hostB, 'hostC': hostC},
|
||||
'groups': {'groupA': groupA, 'groupB': groupB},
|
||||
}
|
||||
|
||||
hostA.delete()
|
||||
hostB.delete()
|
||||
hostC.delete()
|
||||
groupA.delete()
|
||||
groupB.delete()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def smart_inventory_factory():
|
||||
created = []
|
||||
|
||||
def _factory(name, host_filter, organization):
|
||||
inv = Inventory.objects.create(name=name, kind='smart', host_filter=host_filter, organization=organization)
|
||||
created.append(inv)
|
||||
return inv
|
||||
|
||||
yield _factory
|
||||
for inv in reversed(created):
|
||||
inv.delete()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def host_factory():
|
||||
created = []
|
||||
|
||||
def _factory(**kwargs):
|
||||
host = Host.objects.create(**kwargs)
|
||||
created.append(host)
|
||||
return host
|
||||
|
||||
yield _factory
|
||||
for host in reversed(created):
|
||||
if host.pk is not None:
|
||||
host.delete()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def group_factory():
|
||||
created = []
|
||||
|
||||
def _factory(**kwargs):
|
||||
group = Group.objects.create(**kwargs)
|
||||
created.append(group)
|
||||
return group
|
||||
|
||||
yield _factory
|
||||
for group in reversed(created):
|
||||
group.delete()
|
||||
|
||||
|
||||
def query_names(filter_string):
|
||||
return sorted(SmartFilter.query_from_string(filter_string).distinct().values_list('name', flat=True))
|
||||
|
||||
|
||||
# --- Fact-based filter tests (require PostgreSQL for JSONField __contains) ---
|
||||
|
||||
|
||||
def test_fact_based_host_filter(fact_inventory):
|
||||
assert query_names('ansible_facts__ansible_system=Linux') == ['factHostA', 'factHostB']
|
||||
assert query_names('ansible_facts__ansible_distribution=CentOS') == ['factHostA']
|
||||
assert query_names('ansible_facts__ansible_distribution=Ubuntu') == ['factHostB']
|
||||
assert query_names('ansible_facts__ansible_system=Darwin') == ['factHostC']
|
||||
assert query_names('ansible_facts__ansible_system=Windows') == []
|
||||
|
||||
|
||||
def test_nested_fact_search(fact_inventory):
|
||||
assert query_names('ansible_facts__ansible_python__version__major=3') == ['factHostA', 'factHostB', 'factHostC']
|
||||
assert query_names('ansible_facts__ansible_python__version__minor=9') == ['factHostA']
|
||||
assert query_names('ansible_facts__ansible_python__version__minor=11') == ['factHostB']
|
||||
assert query_names('ansible_facts__ansible_env__HOME=/root') == ['factHostA']
|
||||
|
||||
|
||||
def test_list_fact_search(fact_inventory):
|
||||
assert query_names('ansible_facts__ansible_python__version_info[]=9') == ['factHostA']
|
||||
assert query_names('ansible_facts__ansible_python__version_info[]=11') == ['factHostB']
|
||||
assert query_names('ansible_facts__ansible_python__version_info[]=3') == ['factHostA', 'factHostB', 'factHostC']
|
||||
|
||||
|
||||
def test_fact_search_with_or(fact_inventory):
|
||||
assert query_names('ansible_facts__ansible_system=Linux or ansible_facts__ansible_system=Linux') == ['factHostA', 'factHostB']
|
||||
assert query_names('ansible_facts__ansible_system=Linux or ansible_facts__ansible_system=not_found') == ['factHostA', 'factHostB']
|
||||
assert query_names('ansible_facts__ansible_system=not_found or ansible_facts__ansible_system=not_found') == []
|
||||
assert query_names('ansible_facts__ansible_system=Linux or ansible_facts__ansible_system=Darwin') == ['factHostA', 'factHostB', 'factHostC']
|
||||
|
||||
|
||||
def test_fact_search_with_and(fact_inventory):
|
||||
assert query_names('ansible_facts__ansible_system=Linux and ansible_facts__ansible_system=Linux') == ['factHostA', 'factHostB']
|
||||
assert query_names('ansible_facts__ansible_system=Linux and ansible_facts__ansible_system=not_found') == []
|
||||
assert query_names('ansible_facts__ansible_system=Linux and ansible_facts__ansible_distribution=CentOS') == ['factHostA']
|
||||
|
||||
|
||||
def test_hybrid_fact_name_group_search(fact_inventory):
|
||||
assert query_names('name=factHostA or groups__name=factGroupB or ansible_facts__ansible_system=Linux') == ['factHostA', 'factHostB', 'factHostC']
|
||||
|
||||
assert query_names('name=factHostA or groups__name=factGroupA or ansible_facts__ansible_system=not_found') == ['factHostA', 'factHostC']
|
||||
|
||||
assert query_names('name=factHostA and groups__name=factGroupA and ansible_facts__ansible_system=not_found') == []
|
||||
|
||||
assert query_names('name=factHostA and groups__name=factGroupA and ansible_facts__ansible_system=Linux') == ['factHostA']
|
||||
|
||||
|
||||
def test_advanced_hybrid_with_parentheses(fact_inventory):
|
||||
assert query_names('name=factHostA or (groups__name=factGroupB and ansible_facts__ansible_system=not_found)') == ['factHostA']
|
||||
|
||||
assert query_names('name=not_found or (groups__name=factGroupB and ansible_facts__ansible_system=Linux)') == ['factHostB']
|
||||
|
||||
assert query_names('(name=factHostA or groups__name=factGroupB) and ansible_facts__ansible_system=not_found') == []
|
||||
|
||||
assert query_names('(name=factHostA or groups__name=factGroupB) and ansible_facts__ansible_system=Linux') == ['factHostA', 'factHostB']
|
||||
|
||||
assert query_names('(name=factHostC or groups__name=factGroupA) and ansible_facts__ansible_system=Darwin') == ['factHostC']
|
||||
|
||||
|
||||
# --- Smart inventory host resolution tests (require PostgreSQL for DISTINCT ON) ---
|
||||
|
||||
|
||||
def test_smart_inventory_hosts_by_name(fact_inventory, smart_inventory_factory):
|
||||
org = fact_inventory['org']
|
||||
smart_inv = smart_inventory_factory('smart-by-name', 'name=factHostA', org)
|
||||
hosts = sorted(smart_inv.hosts.values_list('name', flat=True))
|
||||
assert hosts == ['factHostA']
|
||||
|
||||
|
||||
def test_smart_inventory_hosts_by_group(fact_inventory, smart_inventory_factory):
|
||||
org = fact_inventory['org']
|
||||
smart_inv = smart_inventory_factory('smart-by-group', 'groups__name=factGroupA', org)
|
||||
hosts = sorted(smart_inv.hosts.values_list('name', flat=True))
|
||||
assert hosts == ['factHostA', 'factHostC']
|
||||
|
||||
|
||||
def test_smart_inventory_with_facts(fact_inventory, smart_inventory_factory):
|
||||
org = fact_inventory['org']
|
||||
smart_inv = smart_inventory_factory('fact-smart-inv', 'ansible_facts__ansible_system=Linux', org)
|
||||
hosts = sorted(smart_inv.hosts.values_list('name', flat=True))
|
||||
assert hosts == ['factHostA', 'factHostB']
|
||||
assert smart_inv.total_hosts == 2
|
||||
|
||||
|
||||
def test_smart_inventory_with_nested_facts(fact_inventory, smart_inventory_factory):
|
||||
org = fact_inventory['org']
|
||||
smart_inv = smart_inventory_factory(
|
||||
'nested-fact-smart-inv',
|
||||
'ansible_facts__ansible_distribution=CentOS and ansible_facts__ansible_python__version__minor=9',
|
||||
org,
|
||||
)
|
||||
hosts = list(smart_inv.hosts.values_list('name', flat=True))
|
||||
assert hosts == ['factHostA']
|
||||
|
||||
|
||||
def test_host_filter_is_organization_scoped(fact_inventory, smart_inventory_factory, host_factory):
|
||||
"""Smart inventory only includes hosts from its own organization."""
|
||||
org1 = fact_inventory['org']
|
||||
org2, _ = Organization.objects.get_or_create(name='smart-inv-other-org')
|
||||
inv2, _ = Inventory.objects.get_or_create(name='other-org-inv', organization=org2)
|
||||
Host.objects.filter(name='factHostA', inventory=inv2).delete()
|
||||
host_factory(name='factHostA', inventory=inv2)
|
||||
|
||||
smart_inv = smart_inventory_factory('scoped-smart', 'name=factHostA', org1)
|
||||
hosts = list(smart_inv.hosts.all())
|
||||
assert len(hosts) == 1
|
||||
assert hosts[0].inventory_id == fact_inventory['inv'].id
|
||||
|
||||
|
||||
def test_duplicate_hosts_deduplicated(smart_inventory_factory, host_factory):
|
||||
"""Same-name hosts across inventories in the same org yield only one smart inventory entry."""
|
||||
org, _ = Organization.objects.get_or_create(name='smart-inv-dedup-org')
|
||||
inv1, _ = Inventory.objects.get_or_create(name='dedup-inv1', organization=org)
|
||||
inv2, _ = Inventory.objects.get_or_create(name='dedup-inv2', organization=org)
|
||||
Host.objects.filter(name='dedup_host', inventory__in=[inv1, inv2]).delete()
|
||||
host1 = host_factory(name='dedup_host', inventory=inv1)
|
||||
host2 = host_factory(name='dedup_host', inventory=inv2)
|
||||
|
||||
smart_inv = smart_inventory_factory('dedup-smart', 'name=dedup_host', org)
|
||||
hosts = list(smart_inv.hosts.all())
|
||||
assert len(hosts) == 1
|
||||
assert hosts[0].id == min(host1.id, host2.id)
|
||||
|
||||
|
||||
def test_host_sources_original_inventory(fact_inventory, smart_inventory_factory):
|
||||
"""Hosts in a smart inventory still reference their source inventory."""
|
||||
org = fact_inventory['org']
|
||||
source_inv = fact_inventory['inv']
|
||||
|
||||
smart_inv = smart_inventory_factory('sources-original', 'name=factHostA', org)
|
||||
host = smart_inv.hosts.first()
|
||||
assert host.inventory_id == source_inv.id
|
||||
|
||||
|
||||
def test_host_updates_reflected_in_smart_inventory(fact_inventory, smart_inventory_factory, host_factory):
|
||||
"""Editing or deleting a host is immediately reflected in a smart inventory."""
|
||||
org = fact_inventory['org']
|
||||
inv = fact_inventory['inv']
|
||||
host = host_factory(name='mutable_host', inventory=inv)
|
||||
|
||||
smart_inv = smart_inventory_factory('updates-reflected', 'name=mutable_host', org)
|
||||
assert smart_inv.hosts.count() == 1
|
||||
|
||||
host.description = 'updated'
|
||||
host.save()
|
||||
assert smart_inv.hosts.first().description == 'updated'
|
||||
|
||||
host.delete()
|
||||
assert smart_inv.hosts.count() == 0
|
||||
|
||||
|
||||
def test_smart_inventory_duplicate_hosts_matching_group_names(fact_inventory, smart_inventory_factory, host_factory, group_factory):
|
||||
"""A host in multiple groups whose names match an icontains filter appears only once."""
|
||||
org = fact_inventory['org']
|
||||
inv = fact_inventory['inv']
|
||||
g1 = group_factory(name='dedup_another_group', inventory=inv)
|
||||
g2 = group_factory(name='dedup_yet_another_group', inventory=inv)
|
||||
host = host_factory(name='dedup_grouped_host', inventory=inv)
|
||||
g1.hosts.add(host)
|
||||
g2.hosts.add(host)
|
||||
|
||||
smart_inv = smart_inventory_factory('group-dedup-smart', 'groups__name__icontains=dedup_another', org)
|
||||
assert smart_inv.hosts.count() == 1
|
||||
|
||||
|
||||
def test_smart_inventory_computed_fields(fact_inventory, smart_inventory_factory):
|
||||
"""Smart inventory total_hosts and related computed fields are accurate."""
|
||||
org = fact_inventory['org']
|
||||
smart_inv = smart_inventory_factory('computed-fields', 'name=factHostA or name=factHostB', org)
|
||||
assert smart_inv.total_hosts == 2
|
||||
assert smart_inv.total_groups == 0
|
||||
assert smart_inv.total_inventory_sources == 0
|
||||
assert smart_inv.has_inventory_sources is False
|
||||
|
||||
|
||||
def test_smart_inventory_matches_host_filter(fact_inventory, smart_inventory_factory):
|
||||
"""Smart inventory hosts should match the equivalent SmartFilter query."""
|
||||
org = fact_inventory['org']
|
||||
host_filter = 'groups__name=factGroupA or groups__name=factGroupB'
|
||||
|
||||
smart_inv = smart_inventory_factory('match-filter', host_filter, org)
|
||||
smart_names = sorted(smart_inv.hosts.values_list('name', flat=True))
|
||||
filter_names = sorted(SmartFilter.query_from_string(host_filter).distinct().values_list('name', flat=True))
|
||||
assert smart_names == filter_names
|
||||
271
awx/main/tests/unit/analytics/test_core_ship.py
Normal file
271
awx/main/tests/unit/analytics/test_core_ship.py
Normal file
@@ -0,0 +1,271 @@
|
||||
# Copyright (c) 2026 Ansible, Inc.
|
||||
# All Rights Reserved.
|
||||
|
||||
"""Tests for analytics ship() function with mTLS authentication."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from unittest import mock
|
||||
|
||||
from django.test.utils import override_settings
|
||||
|
||||
from awx.main.analytics.core import ship, _get_cert_upload_url
|
||||
|
||||
|
||||
class TestGetCertUploadUrl:
|
||||
"""Test _get_cert_upload_url() helper function."""
|
||||
|
||||
def test_adds_cert_subdomain(self):
|
||||
"""Test that 'cert.' is added to hostname."""
|
||||
url = 'https://analytics.example.com/api/ingress/v1/upload'
|
||||
result = _get_cert_upload_url(url)
|
||||
assert result == 'https://cert.analytics.example.com/api/ingress/v1/upload'
|
||||
|
||||
def test_preserves_existing_cert_subdomain(self):
|
||||
"""Test that existing 'cert.' subdomain is preserved."""
|
||||
url = 'https://cert.analytics.example.com/api/ingress/v1/upload'
|
||||
result = _get_cert_upload_url(url)
|
||||
assert result == 'https://cert.analytics.example.com/api/ingress/v1/upload'
|
||||
|
||||
|
||||
class TestShipMTLS:
|
||||
"""Test ship() function's mTLS authentication path."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Create a temporary tarball for testing."""
|
||||
self.temp_file = tempfile.NamedTemporaryFile(mode='wb', suffix='.tar.gz', delete=False)
|
||||
self.temp_file.write(b'test tarball content')
|
||||
self.temp_file.close()
|
||||
self.tarball_path = self.temp_file.name
|
||||
|
||||
def teardown_method(self):
|
||||
"""Clean up temporary tarball."""
|
||||
if os.path.exists(self.tarball_path):
|
||||
os.unlink(self.tarball_path)
|
||||
|
||||
@override_settings(
|
||||
AUTOMATION_ANALYTICS_URL='https://analytics.example.com/api/ingress/v1/upload',
|
||||
INSIGHTS_AGENT_MIME='application/vnd.redhat.tower.analytics+tgz',
|
||||
INSIGHTS_CERT_PATH='/etc/pki/tls/certs/ca-bundle.crt',
|
||||
REDHAT_USERNAME='test_user',
|
||||
REDHAT_PASSWORD='test_pass', # NOSONAR
|
||||
AWX_TASK_ENV={},
|
||||
)
|
||||
@mock.patch('awx.main.analytics.core.get_awx_http_client_headers')
|
||||
@mock.patch('awx.main.analytics.core._temp_cert_files')
|
||||
@mock.patch('awx.main.analytics.core.get_or_generate_candlepin_certificate')
|
||||
@mock.patch('awx.main.analytics.core.requests.Session')
|
||||
def test_ship_with_mtls_success(self, mock_session_class, mock_get_cert, mock_temp_files, mock_headers):
|
||||
"""Test successful upload with mTLS certificate authentication."""
|
||||
# Mock headers to avoid database access
|
||||
mock_headers.return_value = {'Content-Type': 'application/json'}
|
||||
|
||||
# Mock certificate retrieval
|
||||
mock_get_cert.return_value = ('cert-pem-data', 'key-pem-data')
|
||||
|
||||
# Mock temp files context manager
|
||||
mock_temp_files.return_value.__enter__.return_value = ('/tmp/cert.pem', '/tmp/key.pem')
|
||||
mock_temp_files.return_value.__exit__.return_value = None
|
||||
|
||||
# Mock successful mTLS response
|
||||
mock_response = mock.Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_session = mock.Mock()
|
||||
mock_session.headers = {}
|
||||
mock_session.post.return_value = mock_response
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
result = ship(self.tarball_path)
|
||||
|
||||
assert result is True
|
||||
mock_get_cert.assert_called_once()
|
||||
mock_temp_files.assert_called_once_with('cert-pem-data', 'key-pem-data')
|
||||
mock_session.post.assert_called_once()
|
||||
|
||||
# Verify cert URL is used (cert. subdomain added)
|
||||
call_args = mock_session.post.call_args
|
||||
assert call_args[0][0] == 'https://cert.analytics.example.com/api/ingress/v1/upload'
|
||||
|
||||
# Verify mTLS cert was used
|
||||
call_kwargs = call_args[1]
|
||||
assert call_kwargs['cert'] == ('/tmp/cert.pem', '/tmp/key.pem')
|
||||
|
||||
@override_settings(
|
||||
AUTOMATION_ANALYTICS_URL='https://analytics.example.com/api/ingress/v1/upload',
|
||||
INSIGHTS_AGENT_MIME='application/vnd.redhat.tower.analytics+tgz',
|
||||
INSIGHTS_CERT_PATH='/etc/pki/tls/certs/ca-bundle.crt',
|
||||
REDHAT_USERNAME='test_user',
|
||||
REDHAT_PASSWORD='test_pass', # NOSONAR
|
||||
AWX_TASK_ENV={},
|
||||
)
|
||||
@mock.patch('awx.main.analytics.core.get_awx_http_client_headers')
|
||||
@mock.patch('awx.main.analytics.core.OIDCClient')
|
||||
@mock.patch('awx.main.analytics.core._temp_cert_files')
|
||||
@mock.patch('awx.main.analytics.core.get_or_generate_candlepin_certificate')
|
||||
@mock.patch('awx.main.analytics.core.requests.Session')
|
||||
def test_ship_mtls_fallback_to_oidc_on_cert_failure(self, mock_session_class, mock_get_cert, mock_temp_files, mock_oidc_client, mock_headers):
|
||||
"""Test fallback to OIDC auth when mTLS cert authentication fails."""
|
||||
# Mock headers to avoid database access
|
||||
mock_headers.return_value = {'Content-Type': 'application/json'}
|
||||
|
||||
# Mock certificate retrieval
|
||||
mock_get_cert.return_value = ('cert-pem-data', 'key-pem-data')
|
||||
|
||||
# Mock temp files context manager
|
||||
mock_temp_files.return_value.__enter__.return_value = ('/tmp/cert.pem', '/tmp/key.pem')
|
||||
mock_temp_files.return_value.__exit__.return_value = None
|
||||
|
||||
# Mock failed mTLS response (401 Unauthorized)
|
||||
mock_mtls_response = mock.Mock()
|
||||
mock_mtls_response.status_code = 401
|
||||
mock_session = mock.Mock()
|
||||
mock_session.headers = {}
|
||||
mock_session.post.return_value = mock_mtls_response
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Mock successful OIDC response
|
||||
mock_oidc_response = mock.Mock()
|
||||
mock_oidc_response.status_code = 200
|
||||
mock_oidc_instance = mock.Mock()
|
||||
mock_oidc_instance.make_request.return_value = mock_oidc_response
|
||||
mock_oidc_client.return_value = mock_oidc_instance
|
||||
|
||||
result = ship(self.tarball_path)
|
||||
|
||||
assert result is True
|
||||
# Both mTLS and OIDC should be attempted
|
||||
assert mock_session.post.call_count == 1
|
||||
mock_oidc_instance.make_request.assert_called_once()
|
||||
|
||||
# Verify mTLS used cert URL
|
||||
mtls_call_args = mock_session.post.call_args
|
||||
assert mtls_call_args[0][0] == 'https://cert.analytics.example.com/api/ingress/v1/upload'
|
||||
|
||||
# Verify OIDC used original URL
|
||||
oidc_call_args = mock_oidc_instance.make_request.call_args
|
||||
assert oidc_call_args[0][1] == 'https://analytics.example.com/api/ingress/v1/upload'
|
||||
|
||||
@override_settings(
|
||||
AUTOMATION_ANALYTICS_URL='https://analytics.example.com/api/ingress/v1/upload',
|
||||
INSIGHTS_AGENT_MIME='application/vnd.redhat.tower.analytics+tgz',
|
||||
INSIGHTS_CERT_PATH='/etc/pki/tls/certs/ca-bundle.crt',
|
||||
REDHAT_USERNAME='test_user',
|
||||
REDHAT_PASSWORD='test_pass', # NOSONAR
|
||||
AWX_TASK_ENV={},
|
||||
)
|
||||
@mock.patch('awx.main.analytics.core.get_awx_http_client_headers')
|
||||
@mock.patch('awx.main.analytics.core._temp_cert_files')
|
||||
@mock.patch('awx.main.analytics.core.get_or_generate_candlepin_certificate')
|
||||
@mock.patch('awx.main.analytics.core.OIDCClient')
|
||||
@mock.patch('awx.main.analytics.core.requests.Session')
|
||||
def test_ship_mtls_exception_fallback_to_oidc(self, mock_session_class, mock_oidc_client, mock_get_cert, mock_temp_files, mock_headers):
|
||||
"""Test fallback to OIDC auth when mTLS raises an exception."""
|
||||
# Mock headers to avoid database access
|
||||
mock_headers.return_value = {'Content-Type': 'application/json'}
|
||||
|
||||
# Mock certificate retrieval
|
||||
mock_get_cert.return_value = ('cert-pem-data', 'key-pem-data')
|
||||
|
||||
# Mock temp files context manager raising an exception
|
||||
mock_temp_files.return_value.__enter__.side_effect = OSError('Temp file creation failed')
|
||||
|
||||
# Mock successful OIDC response
|
||||
mock_oidc_response = mock.Mock()
|
||||
mock_oidc_response.status_code = 200
|
||||
mock_oidc_instance = mock.Mock()
|
||||
mock_oidc_instance.make_request.return_value = mock_oidc_response
|
||||
mock_oidc_client.return_value = mock_oidc_instance
|
||||
|
||||
mock_session = mock.Mock()
|
||||
mock_session.headers = {}
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
result = ship(self.tarball_path)
|
||||
|
||||
assert result is True
|
||||
# mTLS should fail, OIDC should succeed
|
||||
mock_oidc_instance.make_request.assert_called_once()
|
||||
|
||||
@override_settings(
|
||||
AUTOMATION_ANALYTICS_URL='https://analytics.example.com/api/ingress/v1/upload',
|
||||
INSIGHTS_AGENT_MIME='application/vnd.redhat.tower.analytics+tgz',
|
||||
INSIGHTS_CERT_PATH='/etc/pki/tls/certs/ca-bundle.crt',
|
||||
REDHAT_USERNAME='test_user',
|
||||
REDHAT_PASSWORD='test_pass', # NOSONAR
|
||||
AWX_TASK_ENV={},
|
||||
)
|
||||
@mock.patch('awx.main.analytics.core.get_awx_http_client_headers')
|
||||
@mock.patch('awx.main.analytics.core.OIDCClient')
|
||||
@mock.patch('awx.main.analytics.core.get_or_generate_candlepin_certificate')
|
||||
@mock.patch('awx.main.analytics.core.requests.Session')
|
||||
def test_ship_no_certificate_available(self, mock_session_class, mock_get_cert, mock_oidc_client, mock_headers):
|
||||
"""Test ship() when no Candlepin certificate is available."""
|
||||
# Mock headers to avoid database access
|
||||
mock_headers.return_value = {'Content-Type': 'application/json'}
|
||||
|
||||
# Mock no certificate available
|
||||
mock_get_cert.return_value = (None, None)
|
||||
|
||||
# Mock successful OIDC response
|
||||
mock_oidc_response = mock.Mock()
|
||||
mock_oidc_response.status_code = 200
|
||||
mock_oidc_instance = mock.Mock()
|
||||
mock_oidc_instance.make_request.return_value = mock_oidc_response
|
||||
mock_oidc_client.return_value = mock_oidc_instance
|
||||
|
||||
mock_session = mock.Mock()
|
||||
mock_session.headers = {}
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
result = ship(self.tarball_path)
|
||||
|
||||
assert result is True
|
||||
# Should skip mTLS and go straight to OIDC
|
||||
mock_oidc_instance.make_request.assert_called_once()
|
||||
|
||||
@override_settings(
|
||||
AUTOMATION_ANALYTICS_URL='https://analytics.example.com/api/ingress/v1/upload',
|
||||
INSIGHTS_AGENT_MIME='application/vnd.redhat.tower.analytics+tgz',
|
||||
INSIGHTS_CERT_PATH='/etc/pki/tls/certs/ca-bundle.crt',
|
||||
REDHAT_USERNAME='test_user',
|
||||
REDHAT_PASSWORD='test_pass', # NOSONAR
|
||||
AWX_TASK_ENV={},
|
||||
)
|
||||
@mock.patch('awx.main.analytics.core.get_awx_http_client_headers')
|
||||
@mock.patch('awx.main.analytics.core.OIDCClient')
|
||||
@mock.patch('awx.main.analytics.core._temp_cert_files')
|
||||
@mock.patch('awx.main.analytics.core.get_or_generate_candlepin_certificate')
|
||||
@mock.patch('awx.main.analytics.core.requests.Session')
|
||||
def test_ship_both_auth_methods_fail(self, mock_session_class, mock_get_cert, mock_temp_files, mock_oidc_client, mock_headers):
|
||||
"""Test ship() when both mTLS and OIDC authentication fail."""
|
||||
# Mock headers to avoid database access
|
||||
mock_headers.return_value = {'Content-Type': 'application/json'}
|
||||
|
||||
# Mock certificate retrieval
|
||||
mock_get_cert.return_value = ('cert-pem-data', 'key-pem-data')
|
||||
|
||||
# Mock temp files context manager
|
||||
mock_temp_files.return_value.__enter__.return_value = ('/tmp/cert.pem', '/tmp/key.pem')
|
||||
mock_temp_files.return_value.__exit__.return_value = None
|
||||
|
||||
# Mock failed mTLS response
|
||||
mock_mtls_response = mock.Mock()
|
||||
mock_mtls_response.status_code = 401
|
||||
mock_session = mock.Mock()
|
||||
mock_session.headers = {}
|
||||
mock_session.post.return_value = mock_mtls_response
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Mock failed OIDC response
|
||||
mock_oidc_response = mock.Mock()
|
||||
mock_oidc_response.status_code = 403
|
||||
mock_oidc_response.text = 'Forbidden'
|
||||
mock_oidc_instance = mock.Mock()
|
||||
mock_oidc_instance.make_request.return_value = mock_oidc_response
|
||||
mock_oidc_client.return_value = mock_oidc_instance
|
||||
|
||||
result = ship(self.tarball_path)
|
||||
|
||||
assert result is False
|
||||
mock_session.post.assert_called_once()
|
||||
mock_oidc_instance.make_request.assert_called_once()
|
||||
310
awx/main/tests/unit/management/commands/test_candlepin_cert.py
Normal file
310
awx/main/tests/unit/management/commands/test_candlepin_cert.py
Normal file
@@ -0,0 +1,310 @@
|
||||
# Copyright (c) 2026 Ansible, Inc.
|
||||
# All Rights Reserved.
|
||||
|
||||
"""Tests for candlepin_cert management command."""
|
||||
|
||||
from io import StringIO
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from django.core.management import call_command
|
||||
from django.test.utils import override_settings
|
||||
|
||||
|
||||
class TestCandlepinCertCommand:
|
||||
"""Tests for candlepin_cert management command."""
|
||||
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert._save_candlepin_registration_to_db')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.CandlepinClient')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.resolve_registration_credentials')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
|
||||
@override_settings(
|
||||
AWX_ANALYTICS_CANDLEPIN_URL='https://test.example.com',
|
||||
AWX_ANALYTICS_CANDLEPIN_CA=None,
|
||||
AWX_ANALYTICS_CANDLEPIN_PROXY_URL=None,
|
||||
)
|
||||
def test_register_success(self, mock_fetch_cert, mock_resolve_creds, mock_client_class, mock_save_reg):
|
||||
"""Test successful registration."""
|
||||
# No existing cert
|
||||
mock_fetch_cert.return_value = (None, None, None)
|
||||
|
||||
# Valid credentials
|
||||
mock_resolve_creds.return_value = ('test_user', 'test_pass', 'test_org', 'install-uuid', None)
|
||||
|
||||
# Mock successful registration
|
||||
mock_client = mock.Mock()
|
||||
mock_client.register_consumer.return_value = ('cert-pem', 'key-pem', 'consumer-uuid')
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Mock successful save
|
||||
mock_save_reg.return_value = True
|
||||
|
||||
out = StringIO()
|
||||
call_command('candlepin_cert', 'register', stdout=out, stderr=StringIO())
|
||||
|
||||
output = out.getvalue()
|
||||
assert 'Registered successfully' in output
|
||||
assert 'consumer-uuid' in output
|
||||
|
||||
mock_client.register_consumer.assert_called_once_with('test_user', 'test_pass', 'test_org', install_uuid='install-uuid')
|
||||
mock_save_reg.assert_called_once_with('cert-pem', 'key-pem', 'consumer-uuid')
|
||||
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
|
||||
def test_register_already_registered_without_force(self, mock_fetch_cert):
|
||||
"""Test registration fails when cert already exists and --force not provided."""
|
||||
# Existing cert
|
||||
mock_fetch_cert.return_value = ('existing-cert', 'existing-key', 'existing-uuid')
|
||||
|
||||
out = StringIO()
|
||||
call_command('candlepin_cert', 'register', stdout=out, stderr=StringIO())
|
||||
|
||||
output = out.getvalue()
|
||||
assert 'already stored' in output
|
||||
assert '--force' in output
|
||||
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert._save_candlepin_registration_to_db')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.CandlepinClient')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.resolve_registration_credentials')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
|
||||
@override_settings(
|
||||
AWX_ANALYTICS_CANDLEPIN_URL='https://test.example.com',
|
||||
AWX_ANALYTICS_CANDLEPIN_CA=None,
|
||||
AWX_ANALYTICS_CANDLEPIN_PROXY_URL=None,
|
||||
)
|
||||
def test_register_with_force_flag(self, mock_fetch_cert, mock_resolve_creds, mock_client_class, mock_save_reg):
|
||||
"""Test registration succeeds with --force even when cert exists."""
|
||||
# Existing cert
|
||||
mock_fetch_cert.return_value = ('existing-cert', 'existing-key', 'existing-uuid')
|
||||
|
||||
# Valid credentials
|
||||
mock_resolve_creds.return_value = ('test_user', 'test_pass', 'test_org', 'install-uuid', None)
|
||||
|
||||
# Mock successful registration
|
||||
mock_client = mock.Mock()
|
||||
mock_client.register_consumer.return_value = ('new-cert-pem', 'new-key-pem', 'new-consumer-uuid')
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Mock successful save
|
||||
mock_save_reg.return_value = True
|
||||
|
||||
out = StringIO()
|
||||
call_command('candlepin_cert', 'register', '--force', stdout=out, stderr=StringIO())
|
||||
|
||||
output = out.getvalue()
|
||||
assert 'Registered successfully' in output
|
||||
|
||||
mock_client.register_consumer.assert_called_once()
|
||||
mock_save_reg.assert_called_once_with('new-cert-pem', 'new-key-pem', 'new-consumer-uuid')
|
||||
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.resolve_registration_credentials')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
|
||||
def test_register_missing_credentials(self, mock_fetch_cert, mock_resolve_creds):
|
||||
"""Test registration fails when credentials are missing."""
|
||||
mock_fetch_cert.return_value = (None, None, None)
|
||||
|
||||
# Missing credentials
|
||||
mock_resolve_creds.return_value = (None, None, None, None, ['username', 'password'])
|
||||
|
||||
err = StringIO()
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
call_command('candlepin_cert', 'register', stderr=err)
|
||||
|
||||
assert exc_info.value.code == 1
|
||||
error_output = err.getvalue()
|
||||
assert 'Missing required value' in error_output
|
||||
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert._save_candlepin_cert_to_db')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.CandlepinClient')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.parse_cert')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.needs_renewal')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
|
||||
@override_settings(
|
||||
AWX_ANALYTICS_CANDLEPIN_URL='https://test.example.com',
|
||||
AWX_ANALYTICS_CANDLEPIN_CA=None,
|
||||
AWX_ANALYTICS_CANDLEPIN_PROXY_URL=None,
|
||||
AWX_ANALYTICS_CANDLEPIN_RENEWAL_THRESHOLD_DAYS=90,
|
||||
)
|
||||
def test_renew_success(self, mock_fetch_cert, mock_needs_renewal, mock_parse_cert, mock_client_class, mock_save_cert):
|
||||
"""Test successful certificate renewal."""
|
||||
# Existing cert
|
||||
mock_fetch_cert.return_value = ('old-cert', 'old-key', 'consumer-uuid')
|
||||
|
||||
# Parse cert returns metadata
|
||||
mock_parse_cert.side_effect = [
|
||||
{'serial': '123', 'cn': 'test', 'not_after': '2026-06-01', 'days_remaining': 10}, # Current cert
|
||||
{'serial': '456', 'cn': 'test', 'not_after': '2027-06-01', 'days_remaining': 365}, # Renewed cert
|
||||
]
|
||||
|
||||
# Renewal needed
|
||||
mock_needs_renewal.return_value = True
|
||||
|
||||
# Mock successful check-in and renewal
|
||||
mock_client = mock.Mock()
|
||||
mock_client.checkin.return_value = True
|
||||
mock_client.regenerate_cert.return_value = ('new-cert', 'new-key')
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_save_cert.return_value = True
|
||||
|
||||
out = StringIO()
|
||||
call_command('candlepin_cert', 'renew', stdout=out, stderr=StringIO())
|
||||
|
||||
output = out.getvalue()
|
||||
assert 'Check-in successful' in output
|
||||
assert 'Certificate renewed successfully' in output
|
||||
assert 'saved to database' in output
|
||||
|
||||
mock_client.checkin.assert_called_once_with('consumer-uuid', 'old-cert', 'old-key')
|
||||
mock_client.regenerate_cert.assert_called_once()
|
||||
mock_save_cert.assert_called_once_with('new-cert', 'new-key')
|
||||
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
|
||||
def test_renew_no_cert_in_db(self, mock_fetch_cert):
|
||||
"""Test renew fails when no certificate exists in database."""
|
||||
mock_fetch_cert.return_value = (None, None, None)
|
||||
|
||||
err = StringIO()
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
call_command('candlepin_cert', 'renew', stderr=err)
|
||||
|
||||
assert exc_info.value.code == 1
|
||||
error_output = err.getvalue()
|
||||
assert 'No Candlepin identity certificate found' in error_output
|
||||
assert 'Run the register subcommand first' in error_output
|
||||
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.CandlepinClient')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.parse_cert')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.needs_renewal')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
|
||||
@override_settings(
|
||||
AWX_ANALYTICS_CANDLEPIN_URL='https://test.example.com',
|
||||
AWX_ANALYTICS_CANDLEPIN_CA=None,
|
||||
AWX_ANALYTICS_CANDLEPIN_PROXY_URL=None,
|
||||
AWX_ANALYTICS_CANDLEPIN_RENEWAL_THRESHOLD_DAYS=90,
|
||||
)
|
||||
def test_renew_not_needed(self, mock_fetch_cert, mock_needs_renewal, mock_parse_cert, mock_client_class):
|
||||
"""Test renew when certificate is still valid and renewal not needed."""
|
||||
mock_fetch_cert.return_value = ('cert', 'key', 'consumer-uuid')
|
||||
|
||||
# Parse cert returns healthy cert
|
||||
mock_parse_cert.return_value = {'serial': '123', 'cn': 'test', 'not_after': '2027-01-01', 'days_remaining': 200}
|
||||
|
||||
# Renewal not needed
|
||||
mock_needs_renewal.return_value = False
|
||||
|
||||
# Mock successful check-in
|
||||
mock_client = mock.Mock()
|
||||
mock_client.checkin.return_value = True
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
out = StringIO()
|
||||
call_command('candlepin_cert', 'renew', stdout=out, stderr=StringIO())
|
||||
|
||||
output = out.getvalue()
|
||||
assert 'Check-in successful' in output
|
||||
assert 'No renewal needed' in output
|
||||
|
||||
mock_client.checkin.assert_called_once()
|
||||
mock_client.regenerate_cert.assert_not_called()
|
||||
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert._save_candlepin_cert_to_db')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.CandlepinClient')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.parse_cert')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.needs_renewal')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
|
||||
@override_settings(
|
||||
AWX_ANALYTICS_CANDLEPIN_URL='https://test.example.com',
|
||||
AWX_ANALYTICS_CANDLEPIN_CA=None,
|
||||
AWX_ANALYTICS_CANDLEPIN_PROXY_URL=None,
|
||||
AWX_ANALYTICS_CANDLEPIN_RENEWAL_THRESHOLD_DAYS=90,
|
||||
)
|
||||
def test_renew_with_force_flag(self, mock_fetch_cert, mock_needs_renewal, mock_parse_cert, mock_client_class, mock_save_cert):
|
||||
"""Test renew --force renews even when not needed."""
|
||||
mock_fetch_cert.return_value = ('cert', 'key', 'consumer-uuid')
|
||||
|
||||
# Parse cert
|
||||
mock_parse_cert.side_effect = [
|
||||
{'serial': '123', 'cn': 'test', 'not_after': '2027-01-01', 'days_remaining': 200}, # Current cert (healthy)
|
||||
{'serial': '456', 'cn': 'test', 'not_after': '2027-06-01', 'days_remaining': 365}, # New cert
|
||||
]
|
||||
|
||||
# Would not need renewal without --force
|
||||
mock_needs_renewal.return_value = False
|
||||
|
||||
# Mock successful operations
|
||||
mock_client = mock.Mock()
|
||||
mock_client.checkin.return_value = True
|
||||
mock_client.regenerate_cert.return_value = ('new-cert', 'new-key')
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_save_cert.return_value = True
|
||||
|
||||
out = StringIO()
|
||||
call_command('candlepin_cert', 'renew', '--force', stdout=out, stderr=StringIO())
|
||||
|
||||
output = out.getvalue()
|
||||
assert 'forced via --force' in output
|
||||
assert 'Certificate renewed successfully' in output
|
||||
|
||||
mock_client.regenerate_cert.assert_called_once()
|
||||
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.CandlepinClient')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.parse_cert')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.needs_renewal')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
|
||||
@override_settings(
|
||||
AWX_ANALYTICS_CANDLEPIN_URL='https://test.example.com',
|
||||
AWX_ANALYTICS_CANDLEPIN_CA=None,
|
||||
AWX_ANALYTICS_CANDLEPIN_PROXY_URL=None,
|
||||
AWX_ANALYTICS_CANDLEPIN_RENEWAL_THRESHOLD_DAYS=90,
|
||||
)
|
||||
def test_renew_checkin_failure(self, mock_fetch_cert, mock_needs_renewal, mock_parse_cert, mock_client_class):
|
||||
"""Test renew handles check-in failure gracefully."""
|
||||
mock_fetch_cert.return_value = ('cert', 'key', 'consumer-uuid')
|
||||
|
||||
mock_parse_cert.return_value = {'serial': '123', 'cn': 'test', 'not_after': '2027-01-01', 'days_remaining': 100}
|
||||
mock_needs_renewal.return_value = False # Not needed for renewal, just testing check-in failure
|
||||
|
||||
# Mock failed check-in
|
||||
mock_client = mock.Mock()
|
||||
mock_client.checkin.return_value = False
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
err = StringIO()
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
call_command('candlepin_cert', 'renew', stderr=err)
|
||||
|
||||
assert exc_info.value.code == 1
|
||||
error_output = err.getvalue()
|
||||
assert 'Check-in with Candlepin failed' in error_output
|
||||
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.CandlepinClient')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.parse_cert')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert.needs_renewal')
|
||||
@mock.patch('awx.main.management.commands.candlepin_cert._fetch_candlepin_cert_from_db')
|
||||
@override_settings(
|
||||
AWX_ANALYTICS_CANDLEPIN_URL='https://test.example.com',
|
||||
AWX_ANALYTICS_CANDLEPIN_CA=None,
|
||||
AWX_ANALYTICS_CANDLEPIN_PROXY_URL=None,
|
||||
AWX_ANALYTICS_CANDLEPIN_RENEWAL_THRESHOLD_DAYS=90,
|
||||
)
|
||||
def test_renew_regenerate_cert_failure(self, mock_fetch_cert, mock_needs_renewal, mock_parse_cert, mock_client_class):
|
||||
"""Test renew handles certificate regeneration failure."""
|
||||
mock_fetch_cert.return_value = ('cert', 'key', 'consumer-uuid')
|
||||
|
||||
mock_parse_cert.return_value = {'serial': '123', 'cn': 'test', 'not_after': '2026-06-01', 'days_remaining': 10}
|
||||
mock_needs_renewal.return_value = True
|
||||
|
||||
# Mock successful check-in but failed regeneration
|
||||
mock_client = mock.Mock()
|
||||
mock_client.checkin.return_value = True
|
||||
mock_client.regenerate_cert.side_effect = Exception('Certificate regeneration failed')
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
err = StringIO()
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
call_command('candlepin_cert', 'renew', stderr=err)
|
||||
|
||||
assert exc_info.value.code == 1
|
||||
error_output = err.getvalue()
|
||||
assert 'Certificate renewal failed' in error_output
|
||||
35
awx/main/tests/unit/management/commands/test_check_db.py
Normal file
35
awx/main/tests/unit/management/commands/test_check_db.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import pytest
|
||||
from django.core.management.base import CommandError
|
||||
|
||||
from awx.main.management.commands.check_db import Command
|
||||
|
||||
|
||||
def test_check_db_command_success(mocker):
|
||||
mock_cursor = mocker.MagicMock()
|
||||
mock_cursor.fetchone.return_value = ['PostgreSQL 12.8 on x86_64-pc-linux-gnu, compiled by gcc (GCC) 9.3.0, 64-bit']
|
||||
mock_connection = mocker.MagicMock()
|
||||
mock_connection.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
mocker.patch('awx.main.management.commands.check_db.connection', mock_connection)
|
||||
mocker.patch('awx.main.management.commands.check_db.db_requirement_violations', return_value=None)
|
||||
|
||||
command = Command()
|
||||
result = command.handle()
|
||||
|
||||
assert 'Database Version:' in result
|
||||
mock_cursor.execute.assert_called_once_with('SELECT version()')
|
||||
|
||||
|
||||
def test_check_db_command_version_violations(mocker):
|
||||
mock_cursor = mocker.MagicMock()
|
||||
mock_cursor.fetchone.return_value = ['PostgreSQL 11.0 on x86_64-pc-linux-gnu']
|
||||
mock_connection = mocker.MagicMock()
|
||||
mock_connection.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
mocker.patch('awx.main.management.commands.check_db.connection', mock_connection)
|
||||
violation_msg = "At a minimum, postgres version 12 is required, found 11\n"
|
||||
mocker.patch('awx.main.management.commands.check_db.db_requirement_violations', return_value=violation_msg)
|
||||
|
||||
command = Command()
|
||||
with pytest.raises(CommandError) as exc_info:
|
||||
command.handle()
|
||||
|
||||
assert str(exc_info.value) == violation_msg
|
||||
@@ -112,7 +112,9 @@ def test_finish_job_fact_cache_clear(hosts, mocker, ref_time, tmpdir):
|
||||
os.remove(os.path.join(fact_cache_dir, hosts[1].name))
|
||||
|
||||
hosts_qs = mock.MagicMock()
|
||||
hosts_qs.filter.return_value.order_by.return_value.iterator.return_value = iter(hosts)
|
||||
# The new code calls host_qs.filter(name__in=...).select_related('inventory')
|
||||
# Only hosts[1] needs clearing (its file was removed), so return just that host
|
||||
hosts_qs.filter.return_value.select_related.return_value = [hosts[1]]
|
||||
|
||||
finish_fact_cache(hosts_qs, artifacts_dir=artifacts_dir, inventory_id=inventory_id)
|
||||
|
||||
@@ -145,10 +147,8 @@ def test_finish_job_fact_cache_with_bad_data(hosts, mocker, tmpdir):
|
||||
os.utime(filepath, (new_modification_time, new_modification_time))
|
||||
|
||||
hosts_qs = mock.MagicMock()
|
||||
hosts_qs.filter.return_value.order_by.return_value.iterator.return_value = iter(hosts)
|
||||
|
||||
finish_fact_cache(hosts_qs, artifacts_dir=artifacts_dir, inventory_id=inventory_id)
|
||||
|
||||
# Invalid JSON should be skipped — no hosts updated
|
||||
updated_hosts = bulk_update.call_args[0][1]
|
||||
assert updated_hosts == []
|
||||
# Invalid JSON should be skipped — no hosts updated, bulk_update never called
|
||||
bulk_update.assert_not_called()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from unittest import mock
|
||||
|
||||
from awx.main.models import UnifiedJob, UnifiedJobTemplate, WorkflowJob, WorkflowJobNode, WorkflowApprovalTemplate, Job, User, Project, JobTemplate, Inventory
|
||||
from awx.main.constants import JOB_VARIABLE_PREFIXES
|
||||
from awx.main.utils.common import get_job_variable_prefixes
|
||||
|
||||
|
||||
def test_incorrectly_formatted_variables():
|
||||
@@ -50,7 +50,7 @@ class TestMetaVars:
|
||||
maker = User(username='joe', pk=47, id=47)
|
||||
inv = Inventory(name='example-inv', id=45)
|
||||
result_hash = {}
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in get_job_variable_prefixes():
|
||||
result_hash['{}_job_id'.format(name)] = 42
|
||||
result_hash['{}_job_launch_type'.format(name)] = 'manual'
|
||||
result_hash['{}_user_name'.format(name)] = 'joe'
|
||||
@@ -75,8 +75,48 @@ class TestMetaVars:
|
||||
project=Project(name='jobs-sync', scm_revision='12345444'),
|
||||
job_template=JobTemplate(name='jobs-jt', id=92, pk=92),
|
||||
).awx_meta_vars()
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in get_job_variable_prefixes():
|
||||
assert data['{}_project_revision'.format(name)] == '12345444'
|
||||
assert '{}_job_template_id'.format(name) in data
|
||||
assert data['{}_job_template_id'.format(name)] == 92
|
||||
assert data['{}_job_template_name'.format(name)] == 'jobs-jt'
|
||||
|
||||
|
||||
class TestGetJobVariablePrefixes:
|
||||
"""Tests for the get_job_variable_prefixes() helper function."""
|
||||
|
||||
def test_default_returns_both(self):
|
||||
from django.conf import settings
|
||||
|
||||
with mock.patch.object(settings, 'INCLUDE_DEPRECATED_AWX_VAR_PREFIX', True, create=True):
|
||||
assert get_job_variable_prefixes() == ['awx', 'tower']
|
||||
|
||||
def test_disabled_returns_tower_only(self):
|
||||
from django.conf import settings
|
||||
|
||||
with mock.patch.object(settings, 'INCLUDE_DEPRECATED_AWX_VAR_PREFIX', False, create=True):
|
||||
assert get_job_variable_prefixes() == ['tower']
|
||||
|
||||
def test_fallback_when_setting_not_available(self):
|
||||
"""When setting is not available, falls back to both prefixes for backward compatibility."""
|
||||
fake_settings = mock.MagicMock(spec=[])
|
||||
with mock.patch('django.conf.settings', fake_settings):
|
||||
assert get_job_variable_prefixes() == ['awx', 'tower']
|
||||
|
||||
def test_job_metavars_both_prefixes(self):
|
||||
"""With INCLUDE_DEPRECATED_AWX_VAR_PREFIX=True, both awx_ and tower_ variables."""
|
||||
from django.conf import settings
|
||||
|
||||
with mock.patch.object(settings, 'INCLUDE_DEPRECATED_AWX_VAR_PREFIX', True, create=True):
|
||||
data = Job(name='fake-job', pk=1, id=1, launch_type='manual').awx_meta_vars()
|
||||
assert 'awx_job_id' in data
|
||||
assert 'tower_job_id' in data
|
||||
|
||||
def test_job_metavars_tower_only(self):
|
||||
"""With INCLUDE_DEPRECATED_AWX_VAR_PREFIX=False, only tower_ prefixed variables."""
|
||||
from django.conf import settings
|
||||
|
||||
with mock.patch.object(settings, 'INCLUDE_DEPRECATED_AWX_VAR_PREFIX', False, create=True):
|
||||
data = Job(name='fake-job', pk=1, id=1, launch_type='manual').awx_meta_vars()
|
||||
assert 'tower_job_id' in data
|
||||
assert 'awx_job_id' not in data
|
||||
|
||||
@@ -10,8 +10,8 @@ def test_send_messages():
|
||||
with mock.patch('awx.main.notifications.grafana_backend.requests') as requests_mock:
|
||||
requests_mock.post.return_value.status_code = 200
|
||||
m = {}
|
||||
m['started'] = dt.datetime.utcfromtimestamp(60).isoformat()
|
||||
m['finished'] = dt.datetime.utcfromtimestamp(120).isoformat()
|
||||
m['started'] = dt.datetime.fromtimestamp(60, tz=dt.timezone.utc).isoformat()
|
||||
m['finished'] = dt.datetime.fromtimestamp(120, tz=dt.timezone.utc).isoformat()
|
||||
m['subject'] = "test subject"
|
||||
backend = grafana_backend.GrafanaBackend("testapikey", dashboardId='', panelId='')
|
||||
message = EmailMessage(
|
||||
@@ -40,8 +40,8 @@ def test_send_messages_with_no_verify_ssl():
|
||||
with mock.patch('awx.main.notifications.grafana_backend.requests') as requests_mock:
|
||||
requests_mock.post.return_value.status_code = 200
|
||||
m = {}
|
||||
m['started'] = dt.datetime.utcfromtimestamp(60).isoformat()
|
||||
m['finished'] = dt.datetime.utcfromtimestamp(120).isoformat()
|
||||
m['started'] = dt.datetime.fromtimestamp(60, tz=dt.timezone.utc).isoformat()
|
||||
m['finished'] = dt.datetime.fromtimestamp(120, tz=dt.timezone.utc).isoformat()
|
||||
m['subject'] = "test subject"
|
||||
backend = grafana_backend.GrafanaBackend("testapikey", dashboardId='', panelId='', grafana_no_verify_ssl=True)
|
||||
message = EmailMessage(
|
||||
@@ -71,8 +71,8 @@ def test_send_messages_with_dashboardid(dashboardId):
|
||||
with mock.patch('awx.main.notifications.grafana_backend.requests') as requests_mock:
|
||||
requests_mock.post.return_value.status_code = 200
|
||||
m = {}
|
||||
m['started'] = dt.datetime.utcfromtimestamp(60).isoformat()
|
||||
m['finished'] = dt.datetime.utcfromtimestamp(120).isoformat()
|
||||
m['started'] = dt.datetime.fromtimestamp(60, tz=dt.timezone.utc).isoformat()
|
||||
m['finished'] = dt.datetime.fromtimestamp(120, tz=dt.timezone.utc).isoformat()
|
||||
m['subject'] = "test subject"
|
||||
backend = grafana_backend.GrafanaBackend("testapikey", dashboardId=dashboardId, panelId='')
|
||||
message = EmailMessage(
|
||||
@@ -102,8 +102,8 @@ def test_send_messages_with_panelid(panelId):
|
||||
with mock.patch('awx.main.notifications.grafana_backend.requests') as requests_mock:
|
||||
requests_mock.post.return_value.status_code = 200
|
||||
m = {}
|
||||
m['started'] = dt.datetime.utcfromtimestamp(60).isoformat()
|
||||
m['finished'] = dt.datetime.utcfromtimestamp(120).isoformat()
|
||||
m['started'] = dt.datetime.fromtimestamp(60, tz=dt.timezone.utc).isoformat()
|
||||
m['finished'] = dt.datetime.fromtimestamp(120, tz=dt.timezone.utc).isoformat()
|
||||
m['subject'] = "test subject"
|
||||
backend = grafana_backend.GrafanaBackend("testapikey", dashboardId='', panelId=panelId)
|
||||
message = EmailMessage(
|
||||
@@ -132,8 +132,8 @@ def test_send_messages_with_bothids():
|
||||
with mock.patch('awx.main.notifications.grafana_backend.requests') as requests_mock:
|
||||
requests_mock.post.return_value.status_code = 200
|
||||
m = {}
|
||||
m['started'] = dt.datetime.utcfromtimestamp(60).isoformat()
|
||||
m['finished'] = dt.datetime.utcfromtimestamp(120).isoformat()
|
||||
m['started'] = dt.datetime.fromtimestamp(60, tz=dt.timezone.utc).isoformat()
|
||||
m['finished'] = dt.datetime.fromtimestamp(120, tz=dt.timezone.utc).isoformat()
|
||||
m['subject'] = "test subject"
|
||||
backend = grafana_backend.GrafanaBackend("testapikey", dashboardId='42', panelId='42')
|
||||
message = EmailMessage(
|
||||
@@ -162,8 +162,8 @@ def test_send_messages_with_emptyids():
|
||||
with mock.patch('awx.main.notifications.grafana_backend.requests') as requests_mock:
|
||||
requests_mock.post.return_value.status_code = 200
|
||||
m = {}
|
||||
m['started'] = dt.datetime.utcfromtimestamp(60).isoformat()
|
||||
m['finished'] = dt.datetime.utcfromtimestamp(120).isoformat()
|
||||
m['started'] = dt.datetime.fromtimestamp(60, tz=dt.timezone.utc).isoformat()
|
||||
m['finished'] = dt.datetime.fromtimestamp(120, tz=dt.timezone.utc).isoformat()
|
||||
m['subject'] = "test subject"
|
||||
backend = grafana_backend.GrafanaBackend("testapikey", dashboardId='', panelId='')
|
||||
message = EmailMessage(
|
||||
@@ -192,8 +192,8 @@ def test_send_messages_with_tags():
|
||||
with mock.patch('awx.main.notifications.grafana_backend.requests') as requests_mock:
|
||||
requests_mock.post.return_value.status_code = 200
|
||||
m = {}
|
||||
m['started'] = dt.datetime.utcfromtimestamp(60).isoformat()
|
||||
m['finished'] = dt.datetime.utcfromtimestamp(120).isoformat()
|
||||
m['started'] = dt.datetime.fromtimestamp(60, tz=dt.timezone.utc).isoformat()
|
||||
m['finished'] = dt.datetime.fromtimestamp(120, tz=dt.timezone.utc).isoformat()
|
||||
m['subject'] = "test subject"
|
||||
backend = grafana_backend.GrafanaBackend("testapikey", dashboardId='', panelId='', annotation_tags=["ansible"])
|
||||
message = EmailMessage(
|
||||
|
||||
@@ -83,11 +83,15 @@ def test_pre_post_run_hook_facts(mock_create_partition, mock_facts_settings, pri
|
||||
host1 = mock.MagicMock(spec=Host, id=1, name='host1', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=inventory)
|
||||
host2 = mock.MagicMock(spec=Host, id=2, name='host2', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=inventory)
|
||||
|
||||
# Mock hosts queryset
|
||||
# Mock hosts queryset — must support .only().filter().order_by().iterator() chain
|
||||
hosts = [host1, host2]
|
||||
qs_hosts = mock.MagicMock(spec=QuerySet)
|
||||
qs_hosts._result_cache = hosts
|
||||
qs_hosts.only.return_value = hosts
|
||||
qs_hosts.__iter__ = lambda self: iter(self._result_cache)
|
||||
qs_hosts.only.return_value = qs_hosts
|
||||
qs_hosts.filter.return_value = qs_hosts
|
||||
qs_hosts.order_by.return_value = qs_hosts
|
||||
qs_hosts.iterator.side_effect = lambda: iter(qs_hosts._result_cache)
|
||||
qs_hosts.count.side_effect = lambda: len(qs_hosts._result_cache)
|
||||
inventory.hosts = qs_hosts
|
||||
|
||||
@@ -154,9 +158,12 @@ def test_pre_post_run_hook_facts_deleted_sliced(
|
||||
host.inventory = mock_inventory
|
||||
hosts.append(host)
|
||||
|
||||
# Mock inventory.hosts behavior
|
||||
# Mock inventory.hosts behavior — must support .only().filter().order_by().iterator() chain
|
||||
mock_qs_hosts = mock.MagicMock()
|
||||
mock_qs_hosts.only.return_value = hosts
|
||||
mock_qs_hosts.only.return_value = mock_qs_hosts
|
||||
mock_qs_hosts.filter.return_value = mock_qs_hosts
|
||||
mock_qs_hosts.order_by.return_value = mock_qs_hosts
|
||||
mock_qs_hosts.iterator.side_effect = lambda: iter(hosts)
|
||||
mock_qs_hosts.count.return_value = 999
|
||||
mock_inventory.hosts = mock_qs_hosts
|
||||
|
||||
@@ -473,7 +480,7 @@ def test_populate_claims_for_adhoc_command(workload_attrs, expected_claims):
|
||||
assert claims == expected_claims
|
||||
|
||||
|
||||
@mock.patch('awx.main.tasks.jobs.get_workload_identity_client')
|
||||
@mock.patch('awx.main.utils.workload_identity.get_workload_identity_client')
|
||||
def test_retrieve_workload_identity_jwt_returns_jwt_from_client(mock_get_client):
|
||||
"""retrieve_workload_identity_jwt returns the JWT string from the client."""
|
||||
mock_client = mock.MagicMock()
|
||||
@@ -502,7 +509,7 @@ def test_retrieve_workload_identity_jwt_returns_jwt_from_client(mock_get_client)
|
||||
assert call_kwargs['claims'][AutomationControllerJobScope.CLAIM_JOB_NAME] == 'Test Job'
|
||||
|
||||
|
||||
@mock.patch('awx.main.tasks.jobs.get_workload_identity_client')
|
||||
@mock.patch('awx.main.utils.workload_identity.get_workload_identity_client')
|
||||
def test_retrieve_workload_identity_jwt_passes_audience_and_scope(mock_get_client):
|
||||
"""retrieve_workload_identity_jwt passes audience and scope to the client."""
|
||||
mock_client = mock.MagicMock()
|
||||
@@ -518,7 +525,7 @@ def test_retrieve_workload_identity_jwt_passes_audience_and_scope(mock_get_clien
|
||||
mock_client.request_workload_jwt.assert_called_once_with(claims={'job_id': 1}, scope=scope, audience=audience)
|
||||
|
||||
|
||||
@mock.patch('awx.main.tasks.jobs.get_workload_identity_client')
|
||||
@mock.patch('awx.main.utils.workload_identity.get_workload_identity_client')
|
||||
def test_retrieve_workload_identity_jwt_passes_workload_ttl(mock_get_client):
|
||||
"""retrieve_workload_identity_jwt passes workload_ttl_seconds when provided."""
|
||||
mock_client = mock.Mock()
|
||||
@@ -542,7 +549,7 @@ def test_retrieve_workload_identity_jwt_passes_workload_ttl(mock_get_client):
|
||||
)
|
||||
|
||||
|
||||
@mock.patch('awx.main.tasks.jobs.get_workload_identity_client')
|
||||
@mock.patch('awx.main.utils.workload_identity.get_workload_identity_client')
|
||||
def test_retrieve_workload_identity_jwt_raises_when_client_not_configured(mock_get_client):
|
||||
"""retrieve_workload_identity_jwt raises RuntimeError when client is None."""
|
||||
mock_get_client.return_value = None
|
||||
@@ -590,3 +597,67 @@ def test_populate_workload_identity_tokens_passes_get_instance_timeout_to_client
|
||||
scope=AutomationControllerJobScope.name,
|
||||
workload_ttl_seconds=expected_ttl,
|
||||
)
|
||||
|
||||
|
||||
class TestRunInventoryUpdatePopulateWorkloadIdentityTokens:
|
||||
"""Tests for RunInventoryUpdate.populate_workload_identity_tokens."""
|
||||
|
||||
def test_cloud_credential_passed_as_additional_credential(self):
|
||||
"""The cloud credential is forwarded to super().populate_workload_identity_tokens via additional_credentials."""
|
||||
cloud_cred = mock.MagicMock(name='cloud_cred')
|
||||
cloud_cred.context = {}
|
||||
|
||||
task = jobs.RunInventoryUpdate()
|
||||
task.instance = mock.MagicMock()
|
||||
task.instance.get_cloud_credential.return_value = cloud_cred
|
||||
task._credentials = []
|
||||
|
||||
with mock.patch.object(jobs.BaseTask, 'populate_workload_identity_tokens') as mock_super:
|
||||
task.populate_workload_identity_tokens()
|
||||
|
||||
mock_super.assert_called_once_with(additional_credentials=[cloud_cred])
|
||||
|
||||
def test_no_cloud_credential_calls_super_with_none(self):
|
||||
"""When there is no cloud credential, super() is called with additional_credentials=None."""
|
||||
task = jobs.RunInventoryUpdate()
|
||||
task.instance = mock.MagicMock()
|
||||
task.instance.get_cloud_credential.return_value = None
|
||||
task._credentials = []
|
||||
|
||||
with mock.patch.object(jobs.BaseTask, 'populate_workload_identity_tokens') as mock_super:
|
||||
task.populate_workload_identity_tokens()
|
||||
|
||||
mock_super.assert_called_once_with(additional_credentials=None)
|
||||
|
||||
def test_additional_credentials_combined_with_cloud_credential(self):
|
||||
"""Caller-supplied additional_credentials are combined with the cloud credential."""
|
||||
cloud_cred = mock.MagicMock(name='cloud_cred')
|
||||
cloud_cred.context = {}
|
||||
extra_cred = mock.MagicMock(name='extra_cred')
|
||||
|
||||
task = jobs.RunInventoryUpdate()
|
||||
task.instance = mock.MagicMock()
|
||||
task.instance.get_cloud_credential.return_value = cloud_cred
|
||||
task._credentials = []
|
||||
|
||||
with mock.patch.object(jobs.BaseTask, 'populate_workload_identity_tokens') as mock_super:
|
||||
task.populate_workload_identity_tokens(additional_credentials=[extra_cred])
|
||||
|
||||
mock_super.assert_called_once_with(additional_credentials=[extra_cred, cloud_cred])
|
||||
|
||||
def test_cloud_credential_override_after_context_set(self):
|
||||
"""After OIDC processing, get_cloud_credential is overridden on the instance when context is populated."""
|
||||
cloud_cred = mock.MagicMock(name='cloud_cred')
|
||||
# Simulate that super().populate_workload_identity_tokens populates context
|
||||
cloud_cred.context = {'workload_identity_token': 'eyJ.test.jwt'}
|
||||
|
||||
task = jobs.RunInventoryUpdate()
|
||||
task.instance = mock.MagicMock()
|
||||
task.instance.get_cloud_credential.return_value = cloud_cred
|
||||
task._credentials = []
|
||||
|
||||
with mock.patch.object(jobs.BaseTask, 'populate_workload_identity_tokens'):
|
||||
task.populate_workload_identity_tokens()
|
||||
|
||||
# The instance's get_cloud_credential should now return the same object with context
|
||||
assert task.instance.get_cloud_credential() is cloud_cred
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
from awx.main.tasks.callback import RunnerCallback
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from unittest import mock
|
||||
|
||||
from awx.main.tasks.callback import RunnerCallback, try_load_query_file
|
||||
from awx.main.constants import ANSIBLE_RUNNER_NEEDS_UPDATE_MESSAGE
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
@@ -50,3 +55,102 @@ def test_special_ansible_runner_message(mock_me):
|
||||
'Traceback:\ngot an unexpected keyword argument\nFile: bar.py\n'
|
||||
f'{ANSIBLE_RUNNER_NEEDS_UPDATE_MESSAGE}'
|
||||
)
|
||||
|
||||
|
||||
SAMPLE_ANSIBLE_DATA = {
|
||||
'installed_collections': {
|
||||
'ansible.builtin': {'version': '2.16.0'},
|
||||
'community.general': {'version': '8.0.0', 'host_query': 'SELECT * FROM hosts'},
|
||||
},
|
||||
'ansible_version': '2.16.0',
|
||||
}
|
||||
|
||||
|
||||
class TestTryLoadQueryFile:
|
||||
def test_loads_file_without_feature_flag(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, 'ansible_data.json')
|
||||
with open(path, 'w') as f:
|
||||
json.dump(SAMPLE_ANSIBLE_DATA, f)
|
||||
|
||||
with mock.patch('awx.main.tasks.callback.flag_enabled', return_value=False):
|
||||
success, data = try_load_query_file(tmpdir)
|
||||
|
||||
assert success is True
|
||||
assert data['ansible_version'] == '2.16.0'
|
||||
assert 'ansible.builtin' in data['installed_collections']
|
||||
|
||||
def test_loads_file_with_feature_flag(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, 'ansible_data.json')
|
||||
with open(path, 'w') as f:
|
||||
json.dump(SAMPLE_ANSIBLE_DATA, f)
|
||||
|
||||
with mock.patch('awx.main.tasks.callback.flag_enabled', return_value=True):
|
||||
success, data = try_load_query_file(tmpdir)
|
||||
|
||||
assert success is True
|
||||
assert data == SAMPLE_ANSIBLE_DATA
|
||||
|
||||
def test_returns_false_when_file_missing(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
success, data = try_load_query_file(tmpdir)
|
||||
|
||||
assert success is False
|
||||
assert data is None
|
||||
|
||||
|
||||
class TestArtifactsHandler:
|
||||
def test_always_persists_metadata_when_flag_off(self, mock_me):
|
||||
rc = RunnerCallback()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, 'ansible_data.json')
|
||||
with open(path, 'w') as f:
|
||||
json.dump(SAMPLE_ANSIBLE_DATA, f)
|
||||
|
||||
with mock.patch('awx.main.tasks.callback.flag_enabled', return_value=False):
|
||||
rc.artifacts_handler(tmpdir)
|
||||
|
||||
assert rc.extra_update_fields['installed_collections'] == SAMPLE_ANSIBLE_DATA['installed_collections']
|
||||
assert rc.extra_update_fields['ansible_version'] == '2.16.0'
|
||||
assert 'event_queries_processed' not in rc.extra_update_fields
|
||||
assert rc.artifacts_processed is True
|
||||
|
||||
@mock.patch('awx.main.tasks.callback.EventQuery')
|
||||
def test_creates_event_queries_when_flag_on(self, mock_event_query, mock_me):
|
||||
rc = RunnerCallback()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, 'ansible_data.json')
|
||||
with open(path, 'w') as f:
|
||||
json.dump(SAMPLE_ANSIBLE_DATA, f)
|
||||
|
||||
with mock.patch('awx.main.tasks.callback.flag_enabled', return_value=True):
|
||||
rc.artifacts_handler(tmpdir)
|
||||
|
||||
assert rc.extra_update_fields['installed_collections'] == SAMPLE_ANSIBLE_DATA['installed_collections']
|
||||
assert rc.extra_update_fields['ansible_version'] == '2.16.0'
|
||||
assert rc.extra_update_fields['event_queries_processed'] is False
|
||||
mock_event_query.assert_called_once()
|
||||
|
||||
@mock.patch('awx.main.tasks.callback.EventQuery')
|
||||
def test_no_event_queries_when_flag_off(self, mock_event_query, mock_me):
|
||||
rc = RunnerCallback()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, 'ansible_data.json')
|
||||
with open(path, 'w') as f:
|
||||
json.dump(SAMPLE_ANSIBLE_DATA, f)
|
||||
|
||||
with mock.patch('awx.main.tasks.callback.flag_enabled', return_value=False):
|
||||
rc.artifacts_handler(tmpdir)
|
||||
|
||||
mock_event_query.assert_not_called()
|
||||
|
||||
def test_handles_missing_artifact_file(self, mock_me):
|
||||
rc = RunnerCallback()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with mock.patch('awx.main.tasks.callback.flag_enabled', return_value=False):
|
||||
rc.artifacts_handler(tmpdir)
|
||||
|
||||
assert 'installed_collections' not in rc.extra_update_fields
|
||||
assert 'ansible_version' not in rc.extra_update_fields
|
||||
assert rc.artifacts_processed is True
|
||||
|
||||
@@ -8,6 +8,7 @@ import pytest
|
||||
|
||||
import awx
|
||||
from awx.main.db.profiled_pg.base import RecordedQueryLog
|
||||
from awx.main.utils.db import db_requirement_violations
|
||||
|
||||
QUERY = {'sql': 'SELECT * FROM main_job', 'time': '.01'}
|
||||
EXPLAIN = 'Seq Scan on public.main_job (cost=0.00..1.18 rows=18 width=86)'
|
||||
@@ -145,3 +146,71 @@ def test_sql_above_threshold(tmpdir):
|
||||
assert q['sql'] == QUERY['sql']
|
||||
assert EXPLAIN in q['explain']
|
||||
assert 'test_sql_above_threshold' in q['bt']
|
||||
|
||||
|
||||
def test_db_requirement_violations_skip_env_var(mocker):
|
||||
mocker.patch.dict(os.environ, {'SKIP_PG_VERSION_CHECK': 'true'})
|
||||
result = db_requirement_violations()
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_db_requirement_violations_postgresql_sufficient_version(mocker):
|
||||
mock_connection = mocker.MagicMock()
|
||||
mock_connection.vendor = 'postgresql'
|
||||
mock_connection.pg_version = 120000 # Version 12.0
|
||||
mocker.patch('awx.main.utils.db.connection', mock_connection)
|
||||
mocker.patch.dict(os.environ, {}, clear=True)
|
||||
|
||||
result = db_requirement_violations()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_db_requirement_violations_postgresql_insufficient_version(mocker):
|
||||
mock_connection = mocker.MagicMock()
|
||||
mock_connection.vendor = 'postgresql'
|
||||
mock_connection.pg_version = 110000 # Version 11.0
|
||||
mocker.patch('awx.main.utils.db.connection', mock_connection)
|
||||
mocker.patch.dict(os.environ, {}, clear=True)
|
||||
|
||||
result = db_requirement_violations()
|
||||
|
||||
assert result is not None
|
||||
assert "At a minimum, postgres version 12 is required, found 11" in result
|
||||
|
||||
|
||||
def test_db_requirement_violations_non_postgresql_production(mocker):
|
||||
mock_connection = mocker.MagicMock()
|
||||
mock_connection.vendor = 'sqlite'
|
||||
mocker.patch('awx.main.utils.db.connection', mock_connection)
|
||||
mocker.patch('awx.main.utils.db.MODE', 'production')
|
||||
mocker.patch.dict(os.environ, {}, clear=True)
|
||||
|
||||
result = db_requirement_violations()
|
||||
|
||||
assert result is not None
|
||||
assert "Running server with 'sqlite' type database is not supported" in result
|
||||
|
||||
|
||||
def test_db_requirement_violations_non_postgresql_development(mocker):
|
||||
mock_connection = mocker.MagicMock()
|
||||
mock_connection.vendor = 'sqlite'
|
||||
mocker.patch('awx.main.utils.db.connection', mock_connection)
|
||||
mocker.patch('awx.main.utils.db.MODE', 'development')
|
||||
mocker.patch.dict(os.environ, {}, clear=True)
|
||||
|
||||
result = db_requirement_violations()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_db_requirement_violations_postgresql_edge_case_version(mocker):
|
||||
mock_connection = mocker.MagicMock()
|
||||
mock_connection.vendor = 'postgresql'
|
||||
mock_connection.pg_version = 129999 # Version 12.9999
|
||||
mocker.patch('awx.main.utils.db.connection', mock_connection)
|
||||
mocker.patch.dict(os.environ, {}, clear=True)
|
||||
|
||||
result = db_requirement_violations()
|
||||
|
||||
assert result is None
|
||||
|
||||
@@ -39,6 +39,13 @@ def create_queries_dir_mock(file_lookup_func):
|
||||
class MockCallbackBase:
|
||||
def __init__(self):
|
||||
self._display = mock.MagicMock()
|
||||
self._plugin_options = {}
|
||||
|
||||
def get_option(self, key):
|
||||
return self._plugin_options.get(key)
|
||||
|
||||
def set_option(self, key, value):
|
||||
self._plugin_options[key] = value
|
||||
|
||||
def v2_playbook_on_stats(self, stats):
|
||||
pass
|
||||
@@ -289,6 +296,7 @@ class TestExternalQueryDiscovery:
|
||||
|
||||
callback = CallbackModule()
|
||||
callback._display = mock.Mock()
|
||||
callback.set_option('collect_host_queries', True)
|
||||
|
||||
with mock.patch('builtins.open', mock.mock_open()):
|
||||
with mock.patch('json.dumps', return_value='{}'):
|
||||
@@ -318,6 +326,7 @@ class TestExternalQueryDiscovery:
|
||||
|
||||
callback = CallbackModule()
|
||||
callback._display = mock.Mock()
|
||||
callback.set_option('collect_host_queries', True)
|
||||
|
||||
with mock.patch('builtins.open', mock.mock_open()):
|
||||
with mock.patch('json.dumps', return_value='{}'):
|
||||
@@ -342,6 +351,7 @@ class TestExternalQueryDiscovery:
|
||||
|
||||
callback = CallbackModule()
|
||||
callback._display = mock.Mock()
|
||||
callback.set_option('collect_host_queries', True)
|
||||
|
||||
with mock.patch('builtins.open', mock.mock_open()):
|
||||
with mock.patch('json.dumps', return_value='{}'):
|
||||
@@ -372,6 +382,7 @@ class TestExternalQueryDiscovery:
|
||||
|
||||
callback = CallbackModule()
|
||||
callback._display = mock.Mock()
|
||||
callback.set_option('collect_host_queries', True)
|
||||
|
||||
with mock.patch('builtins.open', mock.mock_open()):
|
||||
with mock.patch('json.dumps', return_value='{}'):
|
||||
@@ -382,6 +393,28 @@ class TestExternalQueryDiscovery:
|
||||
assert '4.1.0' in call_args
|
||||
assert 'community.vmware' in call_args
|
||||
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.list_collections')
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.files')
|
||||
@mock.patch('awx.playbooks.library.indirect_instance_count.find_external_query_with_fallback')
|
||||
@mock.patch.dict('os.environ', {'AWX_ISOLATED_DATA_DIR': '/tmp/artifacts'})
|
||||
def test_queries_not_collected_when_option_disabled(self, mock_fallback, mock_files, mock_list_collections):
|
||||
"""Host query scanning is skipped when collect_host_queries is disabled."""
|
||||
from awx.playbooks.library.indirect_instance_count import CallbackModule
|
||||
|
||||
mock_list_collections.return_value = [mock.Mock(namespace='demo', name='query', ver='1.0.0', fqcn='demo.query')]
|
||||
|
||||
callback = CallbackModule()
|
||||
callback._display = mock.Mock()
|
||||
callback.set_option('collect_host_queries', False)
|
||||
|
||||
with mock.patch('builtins.open', mock.mock_open()):
|
||||
with mock.patch('json.dumps', return_value='{}'):
|
||||
callback.v2_playbook_on_stats(mock.Mock())
|
||||
|
||||
mock_list_collections.assert_called_once()
|
||||
mock_files.assert_not_called()
|
||||
mock_fallback.assert_not_called()
|
||||
|
||||
|
||||
class TestPrivateDataDirIntegration:
|
||||
"""Tests for vendor collection copying (AC7.10-AC7.11)."""
|
||||
|
||||
@@ -37,7 +37,7 @@ from awx.main.utils import encrypt_field, encrypt_value
|
||||
from awx.main.utils.safe_yaml import SafeLoader
|
||||
|
||||
from awx.main.utils.licensing import Licenser
|
||||
from awx.main.constants import JOB_VARIABLE_PREFIXES
|
||||
from awx.main.utils.common import get_job_variable_prefixes
|
||||
|
||||
from receptorctl.socket_interface import ReceptorControl
|
||||
|
||||
@@ -372,12 +372,12 @@ class TestExtraVarSanitation(TestJobExecution):
|
||||
extra_vars = yaml.load(fd, Loader=SafeLoader)
|
||||
|
||||
# ensure that strings are marked as unsafe
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in get_job_variable_prefixes():
|
||||
for variable_name in ['_job_template_name', '_user_name', '_job_launch_type', '_project_revision', '_inventory_name']:
|
||||
assert hasattr(extra_vars['{}{}'.format(name, variable_name)], '__UNSAFE__')
|
||||
|
||||
# ensure that non-strings are marked as safe
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in get_job_variable_prefixes():
|
||||
for variable_name in ['_job_template_id', '_job_id', '_user_id', '_inventory_id']:
|
||||
assert not hasattr(extra_vars['{}{}'.format(name, variable_name)], '__UNSAFE__')
|
||||
|
||||
@@ -524,7 +524,7 @@ class TestGenericRun:
|
||||
call_args, _ = task._write_extra_vars_file.call_args_list[0]
|
||||
|
||||
private_data_dir, extra_vars, safe_dict = call_args
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in get_job_variable_prefixes():
|
||||
assert extra_vars['{}_user_id'.format(name)] == 123
|
||||
assert extra_vars['{}_user_name'.format(name)] == "angry-spud"
|
||||
|
||||
@@ -615,7 +615,7 @@ class TestAdhocRun(TestJobExecution):
|
||||
call_args, _ = task._write_extra_vars_file.call_args_list[0]
|
||||
|
||||
private_data_dir, extra_vars = call_args
|
||||
for name in JOB_VARIABLE_PREFIXES:
|
||||
for name in get_job_variable_prefixes():
|
||||
assert extra_vars['{}_user_id'.format(name)] == 123
|
||||
assert extra_vars['{}_user_name'.format(name)] == "angry-spud"
|
||||
|
||||
@@ -918,6 +918,81 @@ class TestJobCredentials(TestJobExecution):
|
||||
assert env['FOO'] == 'BAR'
|
||||
|
||||
|
||||
class TestCallbacksEnabled(TestJobExecution):
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_flag_enabled(self):
|
||||
with mock.patch('awx.main.tasks.jobs.flag_enabled', return_value=False):
|
||||
yield
|
||||
|
||||
def test_callbacks_enabled_default(self, patch_Job, private_data_dir, execution_environment, mock_me):
|
||||
job = Job(project=Project(), inventory=Inventory())
|
||||
job.execution_environment = execution_environment
|
||||
|
||||
task = jobs.RunJob()
|
||||
task.instance = job
|
||||
task._write_extra_vars_file = mock.Mock()
|
||||
|
||||
with mock.patch.object(task, 'build_credentials_list', return_value=[], autospec=True):
|
||||
env = task.build_env(job, private_data_dir)
|
||||
|
||||
assert env['ANSIBLE_CALLBACKS_ENABLED'] == 'indirect_instance_count'
|
||||
|
||||
def test_callbacks_enabled_preserves_user_config(self, patch_Job, private_data_dir, execution_environment, mock_me):
|
||||
job = Job(project=Project(), inventory=Inventory())
|
||||
job.execution_environment = execution_environment
|
||||
|
||||
task = jobs.RunJob()
|
||||
task.instance = job
|
||||
task._write_extra_vars_file = mock.Mock()
|
||||
|
||||
with mock.patch.object(task, 'build_credentials_list', return_value=[], autospec=True):
|
||||
with mock.patch('awx.main.tasks.jobs.read_ansible_config', return_value={'callbacks_enabled': 'custom_callback,another_callback'}):
|
||||
env = task.build_env(job, private_data_dir)
|
||||
|
||||
assert env['ANSIBLE_CALLBACKS_ENABLED'] == 'indirect_instance_count,custom_callback,another_callback'
|
||||
|
||||
def test_callbacks_enabled_uses_comma_delimiter(self, patch_Job, private_data_dir, execution_environment, mock_me):
|
||||
job = Job(project=Project(), inventory=Inventory())
|
||||
job.execution_environment = execution_environment
|
||||
|
||||
task = jobs.RunJob()
|
||||
task.instance = job
|
||||
task._write_extra_vars_file = mock.Mock()
|
||||
|
||||
with mock.patch.object(task, 'build_credentials_list', return_value=[], autospec=True):
|
||||
with mock.patch('awx.main.tasks.jobs.read_ansible_config', return_value={'callbacks_enabled': 'my_callback'}):
|
||||
env = task.build_env(job, private_data_dir)
|
||||
|
||||
assert env['ANSIBLE_CALLBACKS_ENABLED'] == 'indirect_instance_count,my_callback'
|
||||
|
||||
def test_collect_host_queries_set_when_flag_on(self, patch_Job, private_data_dir, execution_environment, mock_me):
|
||||
job = Job(project=Project(), inventory=Inventory())
|
||||
job.execution_environment = execution_environment
|
||||
|
||||
task = jobs.RunJob()
|
||||
task.instance = job
|
||||
task._write_extra_vars_file = mock.Mock()
|
||||
|
||||
with mock.patch.object(task, 'build_credentials_list', return_value=[], autospec=True):
|
||||
with mock.patch('awx.main.tasks.jobs.flag_enabled', return_value=True):
|
||||
env = task.build_env(job, private_data_dir)
|
||||
|
||||
assert env['AWX_COLLECT_HOST_QUERIES'] == '1'
|
||||
|
||||
def test_collect_host_queries_not_set_when_flag_off(self, patch_Job, private_data_dir, execution_environment, mock_me):
|
||||
job = Job(project=Project(), inventory=Inventory())
|
||||
job.execution_environment = execution_environment
|
||||
|
||||
task = jobs.RunJob()
|
||||
task.instance = job
|
||||
task._write_extra_vars_file = mock.Mock()
|
||||
|
||||
with mock.patch.object(task, 'build_credentials_list', return_value=[], autospec=True):
|
||||
env = task.build_env(job, private_data_dir)
|
||||
|
||||
assert 'AWX_COLLECT_HOST_QUERIES' not in env
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_Organization")
|
||||
class TestProjectUpdateGalaxyCredentials(TestJobExecution):
|
||||
@pytest.fixture
|
||||
|
||||
@@ -0,0 +1,383 @@
|
||||
# Copyright (c) 2026 Ansible, Inc.
|
||||
# All Rights Reserved.
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from awx.main.utils.candlepin import (
|
||||
_discover_org,
|
||||
_fetch_candlepin_cert_from_db,
|
||||
_fetch_registration_credentials_from_db,
|
||||
_save_candlepin_cert_to_db,
|
||||
_save_candlepin_registration_to_db,
|
||||
_register_candlepin_consumer,
|
||||
_run_candlepin_lifecycle,
|
||||
get_or_generate_candlepin_certificate,
|
||||
resolve_registration_credentials,
|
||||
)
|
||||
|
||||
|
||||
class TestCandlepinCertificateRegistration:
|
||||
"""Tests for Candlepin integration in certificate registration module."""
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.requests.get')
|
||||
@mock.patch('awx.main.utils.candlepin.get_candlepin_ca')
|
||||
def test_discover_org_success(self, mock_get_ca, mock_requests_get):
|
||||
"""Test successful organization discovery."""
|
||||
mock_get_ca.return_value = '/path/to/ca.pem'
|
||||
mock_response = mock.Mock()
|
||||
mock_response.json.return_value = [
|
||||
{'key': 'test_org', 'displayName': 'Test Organization'},
|
||||
{'key': 'other_org', 'displayName': 'Other Organization'},
|
||||
]
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
org = _discover_org('https://candlepin.example.com', 'test_user', 'test_pass')
|
||||
|
||||
assert org == 'test_org'
|
||||
mock_requests_get.assert_called_once_with(
|
||||
'https://candlepin.example.com/users/test_user/owners',
|
||||
auth=('test_user', 'test_pass'),
|
||||
verify='/path/to/ca.pem',
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.requests.get')
|
||||
@mock.patch('awx.main.utils.candlepin.get_candlepin_ca')
|
||||
def test_discover_org_no_ca(self, mock_get_ca, mock_requests_get):
|
||||
"""Test organization discovery without custom CA (uses system certs)."""
|
||||
mock_get_ca.return_value = None
|
||||
mock_response = mock.Mock()
|
||||
mock_response.json.return_value = [{'key': 'test_org', 'displayName': 'Test Organization'}]
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
org = _discover_org('https://candlepin.example.com', 'test_user', 'test_pass')
|
||||
|
||||
assert org == 'test_org'
|
||||
# Should use True for verify when no CA is configured
|
||||
mock_requests_get.assert_called_once_with(
|
||||
'https://candlepin.example.com/users/test_user/owners',
|
||||
auth=('test_user', 'test_pass'),
|
||||
verify=True,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.requests.get')
|
||||
def test_discover_org_no_verify_tls(self, mock_requests_get):
|
||||
"""Test organization discovery with TLS verification disabled."""
|
||||
mock_response = mock.Mock()
|
||||
mock_response.json.return_value = [{'key': 'test_org', 'displayName': 'Test Organization'}]
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
org = _discover_org('https://candlepin.example.com', 'test_user', 'test_pass', verify_tls=False)
|
||||
|
||||
assert org == 'test_org'
|
||||
# Should use False for verify when verify_tls=False
|
||||
mock_requests_get.assert_called_once_with(
|
||||
'https://candlepin.example.com/users/test_user/owners',
|
||||
auth=('test_user', 'test_pass'),
|
||||
verify=False,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.settings')
|
||||
def test_fetch_candlepin_cert_from_db(self, mock_settings):
|
||||
"""Test fetching Candlepin cert from conf_settings."""
|
||||
mock_settings.CANDLEPIN_CONSUMER_UUID = 'test-uuid'
|
||||
mock_settings.CANDLEPIN_CERT_PEM = 'cert-pem-data'
|
||||
mock_settings.CANDLEPIN_KEY_PEM = 'key-pem-data'
|
||||
|
||||
cert, key, uuid = _fetch_candlepin_cert_from_db()
|
||||
|
||||
assert cert == 'cert-pem-data'
|
||||
assert key == 'key-pem-data'
|
||||
assert uuid == 'test-uuid'
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin._discover_org')
|
||||
@mock.patch('awx.main.utils.candlepin.settings')
|
||||
def test_fetch_registration_credentials_from_db(self, mock_settings, mock_discover_org):
|
||||
"""Test fetching registration credentials from settings.
|
||||
|
||||
When both REDHAT and SUBSCRIPTIONS credentials exist, REDHAT takes priority
|
||||
for both authentication and org discovery.
|
||||
"""
|
||||
mock_settings.REDHAT_USERNAME = 'test_user'
|
||||
mock_settings.REDHAT_PASSWORD = 'test_pass'
|
||||
mock_settings.INSTALL_UUID = 'test-install-uuid'
|
||||
mock_settings.SUBSCRIPTIONS_USERNAME = 'subs_user'
|
||||
mock_settings.SUBSCRIPTIONS_PASSWORD = 'subs_pass'
|
||||
mock_discover_org.return_value = 'test_org'
|
||||
|
||||
username, password, org, install_uuid = _fetch_registration_credentials_from_db()
|
||||
|
||||
assert username == 'test_user'
|
||||
assert password == 'test_pass'
|
||||
assert org == 'test_org'
|
||||
assert install_uuid == 'test-install-uuid'
|
||||
# Verify _discover_org was called with REDHAT credentials (takes priority)
|
||||
assert mock_discover_org.call_count == 1
|
||||
args = mock_discover_org.call_args[0]
|
||||
assert args[1] == 'test_user' # REDHAT_USERNAME (selected)
|
||||
assert args[2] == 'test_pass' # REDHAT_PASSWORD (selected)
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin._discover_org')
|
||||
@mock.patch('awx.main.utils.candlepin.settings')
|
||||
def test_fetch_registration_credentials_no_verify_tls(self, mock_settings, mock_discover_org):
|
||||
"""Test fetching credentials passes verify_tls=False to _discover_org.
|
||||
|
||||
Also verifies that selected credentials (REDHAT in this case) are used for org discovery.
|
||||
"""
|
||||
mock_settings.REDHAT_USERNAME = 'test_user'
|
||||
mock_settings.REDHAT_PASSWORD = 'test_pass'
|
||||
mock_settings.INSTALL_UUID = 'test-install-uuid'
|
||||
mock_settings.SUBSCRIPTIONS_USERNAME = 'subs_user'
|
||||
mock_settings.SUBSCRIPTIONS_PASSWORD = 'subs_pass'
|
||||
mock_discover_org.return_value = 'test_org'
|
||||
|
||||
username, password, org, install_uuid = _fetch_registration_credentials_from_db(verify_tls=False)
|
||||
|
||||
assert username == 'test_user'
|
||||
assert password == 'test_pass'
|
||||
assert org == 'test_org'
|
||||
assert install_uuid == 'test-install-uuid'
|
||||
# Verify _discover_org was called with verify_tls=False and REDHAT credentials
|
||||
mock_discover_org.assert_called_once()
|
||||
call_args = mock_discover_org.call_args
|
||||
assert call_args[0][1] == 'test_user' # REDHAT_USERNAME (selected)
|
||||
assert call_args[0][2] == 'test_pass' # REDHAT_PASSWORD (selected)
|
||||
call_kwargs = call_args[1]
|
||||
assert call_kwargs['verify_tls'] is False
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin._fetch_registration_credentials_from_db')
|
||||
def test_resolve_registration_credentials_no_overrides(self, mock_fetch):
|
||||
"""Test resolve_registration_credentials with no overrides."""
|
||||
mock_fetch.return_value = ('db_user', 'db_pass', 'db_org', 'install-uuid')
|
||||
|
||||
username, password, org, install_uuid, errors = resolve_registration_credentials()
|
||||
|
||||
assert username == 'db_user'
|
||||
assert password == 'db_pass'
|
||||
assert org == 'db_org'
|
||||
assert install_uuid == 'install-uuid'
|
||||
assert errors is None
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin._fetch_registration_credentials_from_db')
|
||||
def test_resolve_registration_credentials_with_overrides(self, mock_fetch):
|
||||
"""Test resolve_registration_credentials with CLI overrides."""
|
||||
mock_fetch.return_value = ('db_user', 'db_pass', 'db_org', 'install-uuid')
|
||||
|
||||
username, password, org, install_uuid, errors = resolve_registration_credentials(
|
||||
username_override='cli_user', password_override='cli_pass', org_override='cli_org'
|
||||
)
|
||||
|
||||
assert username == 'cli_user'
|
||||
assert password == 'cli_pass'
|
||||
assert org == 'cli_org'
|
||||
assert install_uuid == 'install-uuid'
|
||||
assert errors is None
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin._fetch_registration_credentials_from_db')
|
||||
def test_resolve_registration_credentials_verify_tls_false(self, mock_fetch):
|
||||
"""Test resolve_registration_credentials passes verify_tls=False to fetch function."""
|
||||
mock_fetch.return_value = ('db_user', 'db_pass', 'db_org', 'install-uuid')
|
||||
|
||||
username, password, org, install_uuid, errors = resolve_registration_credentials(verify_tls=False)
|
||||
|
||||
# Verify _fetch_registration_credentials_from_db was called with verify_tls=False
|
||||
mock_fetch.assert_called_once_with(verify_tls=False)
|
||||
assert username == 'db_user'
|
||||
assert password == 'db_pass'
|
||||
assert org == 'db_org'
|
||||
assert install_uuid == 'install-uuid'
|
||||
assert errors is None
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.parse_cert')
|
||||
@mock.patch('awx.main.utils.candlepin.settings')
|
||||
def test_save_candlepin_cert_to_db(self, mock_settings, mock_parse_cert):
|
||||
"""Test saving Candlepin cert to conf_settings."""
|
||||
mock_parse_cert.return_value = {
|
||||
'serial': '123456',
|
||||
'cn': 'test-consumer',
|
||||
'not_before': '2026-01-01T00:00:00+00:00',
|
||||
'not_after': '2027-01-01T00:00:00+00:00',
|
||||
'days_remaining': 365,
|
||||
}
|
||||
|
||||
result = _save_candlepin_cert_to_db('new-cert', 'new-key')
|
||||
|
||||
assert result is True
|
||||
# Verify settings were assigned
|
||||
assert mock_settings.CANDLEPIN_CERT_PEM == 'new-cert'
|
||||
assert mock_settings.CANDLEPIN_KEY_PEM == 'new-key'
|
||||
assert mock_settings.CANDLEPIN_SERIAL_NUMBER == '123456'
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.parse_cert')
|
||||
@mock.patch('awx.main.utils.candlepin.settings')
|
||||
def test_save_candlepin_registration_to_db(self, mock_settings, mock_parse_cert):
|
||||
"""Test saving Candlepin registration to conf_settings."""
|
||||
mock_parse_cert.return_value = {
|
||||
'serial': '789012',
|
||||
'cn': 'test-consumer',
|
||||
'not_before': '2026-01-01T00:00:00+00:00',
|
||||
'not_after': '2027-01-01T00:00:00+00:00',
|
||||
'days_remaining': 365,
|
||||
}
|
||||
|
||||
result = _save_candlepin_registration_to_db('cert', 'key', 'uuid')
|
||||
|
||||
assert result is True
|
||||
# Verify all registration data was saved
|
||||
assert mock_settings.CANDLEPIN_CONSUMER_UUID == 'uuid'
|
||||
assert mock_settings.CANDLEPIN_CERT_PEM == 'cert'
|
||||
assert mock_settings.CANDLEPIN_KEY_PEM == 'key'
|
||||
assert mock_settings.CANDLEPIN_SERIAL_NUMBER == '789012'
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin._save_candlepin_registration_to_db')
|
||||
@mock.patch('awx.main.utils.candlepin.CandlepinClient')
|
||||
@mock.patch('awx.main.utils.candlepin._fetch_registration_credentials_from_db')
|
||||
@mock.patch('awx.main.utils.candlepin.get_proxy_url')
|
||||
@mock.patch('awx.main.utils.candlepin.get_candlepin_ca')
|
||||
@mock.patch('awx.main.utils.candlepin.get_candlepin_url')
|
||||
def test_register_candlepin_consumer_success(self, mock_get_url, mock_get_ca, mock_get_proxy, mock_fetch_creds, mock_client_class, mock_save):
|
||||
"""Test successful Candlepin consumer registration."""
|
||||
mock_get_url.return_value = 'https://candlepin.example.com'
|
||||
mock_get_ca.return_value = '/path/to/ca.pem'
|
||||
mock_get_proxy.return_value = None
|
||||
mock_fetch_creds.return_value = ('user', 'pass', 'org', 'install-uuid')
|
||||
mock_save.return_value = True
|
||||
|
||||
mock_client = mock.Mock()
|
||||
mock_client.register_consumer.return_value = ('cert', 'key', 'uuid')
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
cert, key, uuid = _register_candlepin_consumer()
|
||||
|
||||
assert cert == 'cert'
|
||||
assert key == 'key'
|
||||
assert uuid == 'uuid'
|
||||
mock_save.assert_called_once_with('cert', 'key', 'uuid')
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin._fetch_registration_credentials_from_db')
|
||||
def test_register_candlepin_consumer_missing_credentials(self, mock_fetch_creds):
|
||||
"""Test registration fails when credentials are missing."""
|
||||
mock_fetch_creds.return_value = (None, None, None, None)
|
||||
|
||||
cert, key, uuid = _register_candlepin_consumer()
|
||||
|
||||
assert cert is None
|
||||
assert key is None
|
||||
assert uuid is None
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin._save_candlepin_cert_to_db')
|
||||
@mock.patch('awx.main.utils.candlepin.run_candlepin_lifecycle')
|
||||
@mock.patch('awx.main.utils.candlepin.get_proxy_url')
|
||||
@mock.patch('awx.main.utils.candlepin.get_candlepin_ca')
|
||||
@mock.patch('awx.main.utils.candlepin.get_renewal_days')
|
||||
@mock.patch('awx.main.utils.candlepin.get_candlepin_url')
|
||||
def test_run_candlepin_lifecycle_with_renewal(self, mock_get_url, mock_get_days, mock_get_ca, mock_get_proxy, mock_lifecycle, mock_save):
|
||||
"""Test lifecycle with certificate renewal."""
|
||||
mock_get_url.return_value = 'https://candlepin.example.com'
|
||||
mock_get_days.return_value = 90
|
||||
mock_get_ca.return_value = '/path/to/ca.pem'
|
||||
mock_get_proxy.return_value = None
|
||||
mock_lifecycle.return_value = ('new-cert', 'new-key')
|
||||
mock_save.return_value = True
|
||||
|
||||
cert, key = _run_candlepin_lifecycle('old-cert', 'old-key', 'real-uuid')
|
||||
|
||||
assert cert == 'new-cert'
|
||||
assert key == 'new-key'
|
||||
mock_lifecycle.assert_called_once()
|
||||
mock_save.assert_called_once_with('new-cert', 'new-key')
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.is_cert_valid')
|
||||
@mock.patch('awx.main.utils.candlepin._run_candlepin_lifecycle')
|
||||
@mock.patch('awx.main.utils.candlepin._fetch_candlepin_cert_from_db')
|
||||
def test_get_or_generate_candlepin_certificate_existing_valid(self, mock_fetch, mock_lifecycle, mock_is_valid):
|
||||
"""Test get_or_generate with existing valid certificate."""
|
||||
mock_fetch.return_value = ('cert-pem', 'key-pem', 'consumer-uuid')
|
||||
mock_lifecycle.return_value = ('cert-pem', 'key-pem')
|
||||
mock_is_valid.return_value = True
|
||||
|
||||
cert, key = get_or_generate_candlepin_certificate()
|
||||
|
||||
assert cert == 'cert-pem'
|
||||
assert key == 'key-pem'
|
||||
mock_lifecycle.assert_called_once_with('cert-pem', 'key-pem', 'consumer-uuid')
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.is_cert_valid')
|
||||
@mock.patch('awx.main.utils.candlepin._run_candlepin_lifecycle')
|
||||
@mock.patch('awx.main.utils.candlepin._register_candlepin_consumer')
|
||||
@mock.patch('awx.main.utils.candlepin._fetch_candlepin_cert_from_db')
|
||||
def test_get_or_generate_candlepin_certificate_register_new(self, mock_fetch, mock_register, mock_lifecycle, mock_is_valid):
|
||||
"""Test get_or_generate when no certificate exists - registers new."""
|
||||
mock_fetch.return_value = (None, None, None)
|
||||
mock_register.return_value = ('new-cert', 'new-key', 'new-uuid')
|
||||
mock_lifecycle.return_value = ('new-cert', 'new-key')
|
||||
mock_is_valid.return_value = True
|
||||
|
||||
cert, key = get_or_generate_candlepin_certificate()
|
||||
|
||||
assert cert == 'new-cert'
|
||||
assert key == 'new-key'
|
||||
mock_register.assert_called_once()
|
||||
mock_lifecycle.assert_called_once_with('new-cert', 'new-key', 'new-uuid')
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin._register_candlepin_consumer')
|
||||
@mock.patch('awx.main.utils.candlepin._fetch_candlepin_cert_from_db')
|
||||
def test_get_or_generate_candlepin_certificate_registration_fails(self, mock_fetch, mock_register):
|
||||
"""Test get_or_generate when registration fails."""
|
||||
mock_fetch.return_value = (None, None, None)
|
||||
mock_register.return_value = (None, None, None)
|
||||
|
||||
cert, key = get_or_generate_candlepin_certificate()
|
||||
|
||||
assert cert is None
|
||||
assert key is None
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.is_cert_valid')
|
||||
@mock.patch('awx.main.utils.candlepin._run_candlepin_lifecycle')
|
||||
@mock.patch('awx.main.utils.candlepin._fetch_candlepin_cert_from_db')
|
||||
def test_get_or_generate_candlepin_certificate_invalid_cert(self, mock_fetch, mock_lifecycle, mock_is_valid):
|
||||
"""Test get_or_generate when certificate is invalid."""
|
||||
mock_fetch.return_value = ('cert-pem', 'key-pem', 'consumer-uuid')
|
||||
mock_lifecycle.return_value = ('cert-pem', 'key-pem')
|
||||
mock_is_valid.return_value = False
|
||||
|
||||
cert, key = get_or_generate_candlepin_certificate()
|
||||
|
||||
assert cert is None
|
||||
assert key is None
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.is_cert_valid')
|
||||
@mock.patch('awx.main.utils.candlepin._run_candlepin_lifecycle')
|
||||
@mock.patch('awx.main.utils.candlepin._fetch_candlepin_cert_from_db')
|
||||
def test_get_or_generate_candlepin_certificate_expired_cert_renewed_successfully(self, mock_fetch, mock_lifecycle, mock_is_valid):
|
||||
"""Test get_or_generate with expired certificate that is successfully renewed."""
|
||||
mock_fetch.return_value = ('expired-cert', 'old-key', 'consumer-uuid')
|
||||
# Lifecycle successfully renews
|
||||
mock_lifecycle.return_value = ('new-cert', 'new-key')
|
||||
# New certificate is valid
|
||||
mock_is_valid.return_value = True
|
||||
|
||||
cert, key = get_or_generate_candlepin_certificate()
|
||||
|
||||
assert cert == 'new-cert'
|
||||
assert key == 'new-key'
|
||||
mock_lifecycle.assert_called_once_with('expired-cert', 'old-key', 'consumer-uuid')
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.parse_cert')
|
||||
@mock.patch('awx.main.utils.candlepin.settings')
|
||||
def test_save_candlepin_registration_to_db_cert_parse_failure(self, mock_settings, mock_parse_cert):
|
||||
"""Test _save_candlepin_registration_to_db handles cert parsing failure gracefully."""
|
||||
# Cert parsing fails
|
||||
mock_parse_cert.side_effect = ValueError('Invalid certificate format')
|
||||
|
||||
result = _save_candlepin_registration_to_db('invalid-cert', 'key-pem', 'consumer-uuid')
|
||||
|
||||
# Should still save registration even if parsing fails
|
||||
assert result is True
|
||||
# Verify UUID, cert, key, and serial (empty string) were saved
|
||||
assert mock_settings.CANDLEPIN_CONSUMER_UUID == 'consumer-uuid'
|
||||
assert mock_settings.CANDLEPIN_CERT_PEM == 'invalid-cert'
|
||||
assert mock_settings.CANDLEPIN_KEY_PEM == 'key-pem'
|
||||
assert mock_settings.CANDLEPIN_SERIAL_NUMBER == ''
|
||||
124
awx/main/tests/unit/utils/test_candlepin_client.py
Normal file
124
awx/main/tests/unit/utils/test_candlepin_client.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# Copyright (c) 2026 Ansible, Inc.
|
||||
# All Rights Reserved.
|
||||
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
from awx.main.utils.candlepin.client import CandlepinClient, _temp_cert_files
|
||||
|
||||
|
||||
class TestCandlepinClient:
|
||||
"""Tests for CandlepinClient."""
|
||||
|
||||
def test_base_url_required(self):
|
||||
"""Test base_url parameter is required."""
|
||||
client = CandlepinClient(base_url='https://subscription.example.com/candlepin')
|
||||
assert client.base_url == 'https://subscription.example.com/candlepin'
|
||||
|
||||
def test_verify_tls_enabled_by_default(self):
|
||||
"""Test TLS verification is enabled by default."""
|
||||
client = CandlepinClient(base_url='https://test.example.com')
|
||||
assert client.verify is True
|
||||
|
||||
def test_verify_tls_with_ca(self):
|
||||
"""Test TLS verification with custom CA."""
|
||||
client = CandlepinClient(base_url='https://test.example.com', candlepin_ca='/path/to/ca.pem')
|
||||
assert client.verify == '/path/to/ca.pem'
|
||||
|
||||
def test_proxy_configuration(self):
|
||||
"""Test proxy configuration."""
|
||||
client = CandlepinClient(base_url='https://test.example.com', proxy='http://proxy.example.com:8080')
|
||||
assert client.proxies == {'https': 'http://proxy.example.com:8080', 'http': 'http://proxy.example.com:8080'}
|
||||
|
||||
def test_temp_cert_files_cleanup(self):
|
||||
"""Test temporary certificate files are created and cleaned up."""
|
||||
cert_pem = '-----BEGIN CERTIFICATE-----\ntest_cert\n-----END CERTIFICATE-----'
|
||||
key_pem = '-----BEGIN PRIVATE KEY-----\ntest_key\n-----END PRIVATE KEY-----'
|
||||
|
||||
with _temp_cert_files(cert_pem, key_pem) as (cert_path, key_path):
|
||||
assert os.path.exists(cert_path)
|
||||
assert os.path.exists(key_path)
|
||||
# Verify file permissions
|
||||
cert_stat = os.stat(cert_path)
|
||||
assert oct(cert_stat.st_mode)[-3:] == '600'
|
||||
|
||||
# Verify cleanup
|
||||
assert not os.path.exists(cert_path)
|
||||
assert not os.path.exists(key_path)
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.client.requests.post')
|
||||
def test_register_consumer_success(self, mock_post):
|
||||
"""Test successful consumer registration."""
|
||||
mock_response = mock.Mock()
|
||||
mock_response.ok = True
|
||||
mock_response.json.return_value = {
|
||||
'uuid': 'test-consumer-uuid',
|
||||
'idCert': {
|
||||
'cert': '-----BEGIN CERTIFICATE-----\ncert_data\n-----END CERTIFICATE-----',
|
||||
'key': '-----BEGIN PRIVATE KEY-----\nkey_data\n-----END PRIVATE KEY-----',
|
||||
},
|
||||
}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
client = CandlepinClient(base_url='https://test.example.com')
|
||||
cert_pem, key_pem, consumer_uuid = client.register_consumer('test_user', 'test_pass', 'test_org', install_uuid='test-install-uuid')
|
||||
|
||||
assert consumer_uuid == 'test-consumer-uuid'
|
||||
assert '-----BEGIN CERTIFICATE-----' in cert_pem
|
||||
assert '-----BEGIN PRIVATE KEY-----' in key_pem
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.client.requests.put')
|
||||
def test_checkin_success(self, mock_put):
|
||||
"""Test successful check-in."""
|
||||
mock_response = mock.Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_put.return_value = mock_response
|
||||
|
||||
client = CandlepinClient(base_url='https://test.example.com')
|
||||
cert_pem = '-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----'
|
||||
key_pem = '-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----'
|
||||
|
||||
result = client.checkin('test-uuid', cert_pem, key_pem)
|
||||
assert result is True
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.client.requests.get')
|
||||
def test_get_consumer_success(self, mock_get):
|
||||
"""Test successful consumer retrieval."""
|
||||
mock_response = mock.Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'uuid': 'test-consumer-uuid',
|
||||
'name': 'aap-12345678',
|
||||
'idCert': {'cert': '-----BEGIN CERTIFICATE-----\nserver_cert\n-----END CERTIFICATE-----', 'serial': {'serial': 123456789}},
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
client = CandlepinClient(base_url='https://test.example.com')
|
||||
cert_pem = '-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----'
|
||||
key_pem = '-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----'
|
||||
|
||||
result = client.get_consumer('test-uuid', cert_pem, key_pem)
|
||||
assert result is not None
|
||||
assert result['uuid'] == 'test-consumer-uuid'
|
||||
assert 'idCert' in result
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.client.requests.post')
|
||||
def test_regenerate_cert_success(self, mock_post):
|
||||
"""Test successful certificate regeneration."""
|
||||
mock_response = mock.Mock()
|
||||
mock_response.ok = True
|
||||
mock_response.json.return_value = {
|
||||
'idCert': {
|
||||
'cert': '-----BEGIN CERTIFICATE-----\nnew_cert\n-----END CERTIFICATE-----',
|
||||
'key': '-----BEGIN PRIVATE KEY-----\nnew_key\n-----END PRIVATE KEY-----',
|
||||
}
|
||||
}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
client = CandlepinClient(base_url='https://test.example.com')
|
||||
old_cert = '-----BEGIN CERTIFICATE-----\nold\n-----END CERTIFICATE-----'
|
||||
old_key = '-----BEGIN PRIVATE KEY-----\nold\n-----END PRIVATE KEY-----'
|
||||
|
||||
new_cert, new_key = client.regenerate_cert('test-uuid', old_cert, old_key)
|
||||
assert 'new_cert' in new_cert
|
||||
assert 'new_key' in new_key
|
||||
222
awx/main/tests/unit/utils/test_candlepin_lifecycle.py
Normal file
222
awx/main/tests/unit/utils/test_candlepin_lifecycle.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# Copyright (c) 2026 Ansible, Inc.
|
||||
# All Rights Reserved.
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest import mock
|
||||
|
||||
from awx.main.utils.candlepin.lifecycle import (
|
||||
parse_cert,
|
||||
needs_renewal,
|
||||
run_candlepin_lifecycle,
|
||||
get_candlepin_url,
|
||||
get_renewal_days,
|
||||
get_candlepin_ca,
|
||||
get_proxy_url,
|
||||
)
|
||||
|
||||
# Sample test certificate (expires far in the future for testing)
|
||||
SAMPLE_CERT_PEM = """-----BEGIN CERTIFICATE-----
|
||||
MIIDXTCCAkWgAwIBAgIJAKJ5VZ2cPQE5MA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV
|
||||
BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
|
||||
aWRnaXRzIFB0eSBMdGQwHhcNMjYwMTAxMDAwMDAwWhcNMjcwMTAxMDAwMDAwWjBF
|
||||
MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
|
||||
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB
|
||||
CgKCAQEA0a7Y3l3X4L7pKq3xDl8vCRrRK6qU5dF7r3xQH5YRz4hZJN9wE3xW0qDT
|
||||
-----END CERTIFICATE-----"""
|
||||
|
||||
|
||||
class TestCandlepinLifecycle:
|
||||
"""Tests for Candlepin lifecycle functions."""
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.settings')
|
||||
def test_get_candlepin_url_default(self, mock_settings):
|
||||
"""Test default Candlepin URL from defaults.py."""
|
||||
mock_settings.AWX_ANALYTICS_CANDLEPIN_URL = 'https://subscription.example.com/candlepin/'
|
||||
url = get_candlepin_url()
|
||||
assert url == 'https://subscription.example.com/candlepin/'
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.settings')
|
||||
def test_get_renewal_days_from_settings(self, mock_settings):
|
||||
"""Test renewal days from Django settings."""
|
||||
mock_settings.AWX_ANALYTICS_CANDLEPIN_RENEWAL_THRESHOLD_DAYS = 45
|
||||
days = get_renewal_days()
|
||||
assert days == 45
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.os.path.isfile')
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.settings')
|
||||
def test_get_candlepin_ca_from_settings(self, mock_settings, mock_isfile):
|
||||
"""Test Candlepin CA from Django settings when file exists."""
|
||||
mock_settings.AWX_ANALYTICS_CANDLEPIN_CA = '/path/to/ca.pem'
|
||||
mock_isfile.return_value = True
|
||||
ca = get_candlepin_ca()
|
||||
assert ca == '/path/to/ca.pem'
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.os.path.isfile')
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.settings')
|
||||
def test_get_candlepin_ca_file_not_found(self, mock_settings, mock_isfile):
|
||||
"""Test Candlepin CA returns None when configured path doesn't exist."""
|
||||
mock_settings.AWX_ANALYTICS_CANDLEPIN_CA = '/path/to/missing.pem'
|
||||
mock_isfile.return_value = False
|
||||
ca = get_candlepin_ca()
|
||||
assert ca is None
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.settings')
|
||||
def test_get_proxy_url_from_settings(self, mock_settings):
|
||||
"""Test proxy URL from Django settings."""
|
||||
mock_settings.AWX_ANALYTICS_CANDLEPIN_PROXY_URL = 'http://proxy.example.com:8080'
|
||||
proxy = get_proxy_url()
|
||||
assert proxy == 'http://proxy.example.com:8080'
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.x509.load_pem_x509_certificate')
|
||||
def test_parse_cert(self, mock_load_cert):
|
||||
"""Test certificate parsing."""
|
||||
# Mock a certificate object
|
||||
mock_cert = mock.Mock()
|
||||
mock_cert.serial_number = 123456
|
||||
mock_cert.not_valid_before_utc = datetime(2026, 1, 1, tzinfo=timezone.utc)
|
||||
mock_cert.not_valid_after_utc = datetime(2027, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
# Mock subject and issuer
|
||||
mock_attr = mock.Mock()
|
||||
mock_attr.oid._name = 'commonName'
|
||||
mock_attr.value = 'test-cn'
|
||||
mock_cert.subject = [mock_attr]
|
||||
mock_cert.issuer = [mock_attr]
|
||||
|
||||
mock_load_cert.return_value = mock_cert
|
||||
|
||||
result = parse_cert('fake-pem')
|
||||
|
||||
assert result['serial'] == '123456'
|
||||
assert result['cn'] == 'test-cn'
|
||||
assert 'not_before' in result
|
||||
assert 'not_after' in result
|
||||
assert 'days_remaining' in result
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.parse_cert')
|
||||
def test_needs_renewal_true(self, mock_parse):
|
||||
"""Test needs_renewal returns True when cert is expiring soon."""
|
||||
mock_parse.return_value = {'days_remaining': 10}
|
||||
|
||||
result = needs_renewal('fake-cert', days_before_expiry=30)
|
||||
assert result is True
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.parse_cert')
|
||||
def test_needs_renewal_false(self, mock_parse):
|
||||
"""Test needs_renewal returns False when cert has time remaining."""
|
||||
mock_parse.return_value = {'days_remaining': 100}
|
||||
|
||||
result = needs_renewal('fake-cert', days_before_expiry=30)
|
||||
assert result is False
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.CandlepinClient')
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.parse_cert')
|
||||
def test_run_candlepin_lifecycle_no_renewal_needed(self, mock_parse, mock_client_class):
|
||||
"""Test lifecycle when no renewal is needed."""
|
||||
mock_parse.return_value = {'serial': '123', 'cn': 'test', 'not_after': '2027-01-01T00:00:00+00:00', 'days_remaining': 100}
|
||||
|
||||
mock_client = mock.Mock()
|
||||
mock_client.checkin.return_value = True
|
||||
mock_client.get_consumer.return_value = None # Skip serial comparison
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
cert_pem, key_pem = run_candlepin_lifecycle('cert-pem', 'key-pem', 'consumer-uuid', candlepin_url='https://test.example.com', renewal_days=30)
|
||||
|
||||
assert cert_pem == 'cert-pem'
|
||||
assert key_pem == 'key-pem'
|
||||
mock_client.checkin.assert_called_once()
|
||||
mock_client.regenerate_cert.assert_not_called()
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.CandlepinClient')
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.parse_cert')
|
||||
def test_run_candlepin_lifecycle_with_renewal(self, mock_parse, mock_client_class):
|
||||
"""Test lifecycle when renewal is needed."""
|
||||
# parse_cert is called multiple times:
|
||||
# 1. Parse original cert
|
||||
# 2. In needs_renewal() to check expiry
|
||||
# 3. Parse new cert after renewal for logging
|
||||
mock_parse.side_effect = [
|
||||
{'serial': '123', 'cn': 'test', 'not_after': '2026-02-01', 'days_remaining': 10}, # Original cert
|
||||
{'serial': '123', 'cn': 'test', 'not_after': '2026-02-01', 'days_remaining': 10}, # needs_renewal check
|
||||
{'serial': '456', 'cn': 'test', 'not_after': '2027-02-01', 'days_remaining': 365}, # New cert
|
||||
]
|
||||
|
||||
mock_client = mock.Mock()
|
||||
mock_client.checkin.return_value = True
|
||||
mock_client.get_consumer.return_value = None # Skip serial comparison
|
||||
mock_client.regenerate_cert.return_value = ('new-cert', 'new-key')
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
cert_pem, key_pem = run_candlepin_lifecycle('old-cert', 'old-key', 'consumer-uuid', renewal_days=90)
|
||||
|
||||
assert cert_pem == 'new-cert'
|
||||
assert key_pem == 'new-key'
|
||||
mock_client.regenerate_cert.assert_called_once()
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.CandlepinClient')
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.parse_cert')
|
||||
def test_run_candlepin_lifecycle_expired_cert_renewal(self, mock_parse, mock_client_class):
|
||||
"""Test lifecycle renews an expired certificate."""
|
||||
# parse_cert called for:
|
||||
# 1. Parse original expired cert
|
||||
# 2. needs_renewal check (expired, so returns True)
|
||||
# 3. Parse new cert after renewal
|
||||
mock_parse.side_effect = [
|
||||
{'serial': '123', 'cn': 'test', 'not_after': '2025-12-31', 'days_remaining': -120}, # Expired cert
|
||||
{'serial': '123', 'cn': 'test', 'not_after': '2025-12-31', 'days_remaining': -120}, # needs_renewal
|
||||
{'serial': '456', 'cn': 'test', 'not_after': '2027-06-01', 'days_remaining': 365}, # New cert
|
||||
]
|
||||
|
||||
mock_client = mock.Mock()
|
||||
mock_client.checkin.return_value = True
|
||||
mock_client.get_consumer.return_value = None
|
||||
mock_client.regenerate_cert.return_value = ('new-cert', 'new-key')
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
cert_pem, key_pem = run_candlepin_lifecycle('expired-cert', 'old-key', 'consumer-uuid', renewal_days=90)
|
||||
|
||||
assert cert_pem == 'new-cert'
|
||||
assert key_pem == 'new-key'
|
||||
mock_client.regenerate_cert.assert_called_once()
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.CandlepinClient')
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.parse_cert')
|
||||
def test_run_candlepin_lifecycle_checkin_failure_revoked_cert(self, mock_parse, mock_client_class):
|
||||
"""Test lifecycle handles check-in failure (e.g., revoked certificate)."""
|
||||
mock_parse.return_value = {'serial': '123', 'cn': 'test', 'not_after': '2027-01-01', 'days_remaining': 100}
|
||||
|
||||
# Check-in fails (could indicate revoked cert or deleted consumer)
|
||||
mock_client = mock.Mock()
|
||||
mock_client.checkin.return_value = False
|
||||
mock_client.get_consumer.return_value = None # get_consumer also fails
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Lifecycle should continue and return original cert
|
||||
cert_pem, key_pem = run_candlepin_lifecycle('cert-pem', 'key-pem', 'consumer-uuid', renewal_days=30)
|
||||
|
||||
assert cert_pem == 'cert-pem'
|
||||
assert key_pem == 'key-pem'
|
||||
mock_client.checkin.assert_called_once()
|
||||
# Regeneration should not be attempted since get_consumer indicates consumer doesn't exist
|
||||
mock_client.regenerate_cert.assert_not_called()
|
||||
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.CandlepinClient')
|
||||
@mock.patch('awx.main.utils.candlepin.lifecycle.parse_cert')
|
||||
def test_run_candlepin_lifecycle_consumer_deleted_server_side(self, mock_parse, mock_client_class):
|
||||
"""Test lifecycle detects when consumer was deleted from Candlepin server."""
|
||||
mock_parse.return_value = {'serial': '123', 'cn': 'test', 'not_after': '2027-01-01', 'days_remaining': 100}
|
||||
|
||||
# Both check-in and get_consumer fail (consumer deleted)
|
||||
mock_client = mock.Mock()
|
||||
mock_client.checkin.return_value = False
|
||||
mock_client.get_consumer.return_value = None
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
cert_pem, key_pem = run_candlepin_lifecycle('cert-pem', 'key-pem', 'consumer-uuid', renewal_days=30)
|
||||
|
||||
# Should return original cert (caller can attempt mTLS, which will fail and fall back to service account)
|
||||
assert cert_pem == 'cert-pem'
|
||||
assert key_pem == 'key-pem'
|
||||
mock_client.checkin.assert_called_once()
|
||||
mock_client.get_consumer.assert_called_once()
|
||||
mock_client.regenerate_cert.assert_not_called()
|
||||
@@ -7,7 +7,7 @@ from django.utils.timezone import now
|
||||
from awx.main.models.schedules import _fast_forward_rrule, Schedule
|
||||
from dateutil.rrule import HOURLY, MINUTELY, MONTHLY
|
||||
|
||||
REF_DT = datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)
|
||||
REF_DT = datetime.datetime(2026, 4, 16, tzinfo=datetime.timezone.utc)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -20,6 +20,10 @@ REF_DT = datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)
|
||||
'DTSTART;TZID=America/New_York:20201118T200000 RRULE:FREQ=MINUTELY;INTERVAL=5;WKST=SU;BYMONTH=2,3;BYMONTHDAY=18;BYHOUR=5;BYMINUTE=35;BYSECOND=0',
|
||||
id='every-5-minutes-at-5:35:00-am-on-the-18th-day-of-feb-or-march-with-week-starting-on-sundays',
|
||||
),
|
||||
pytest.param(
|
||||
'DTSTART;TZID=America/New_York:20251211T130000 RRULE:FREQ=HOURLY;INTERVAL=4;WKST=MO;BYDAY=MO,TU,WE,TH,FR;BYHOUR=1,5,9,13,17,21;BYMINUTE=0',
|
||||
id='every-4-hours-at-1-5-9-13-17-21-am-on-monday-through-friday-with-week-starting-on-monday',
|
||||
),
|
||||
pytest.param(
|
||||
'DTSTART;TZID=America/New_York:20201118T200000 RRULE:FREQ=HOURLY;INTERVAL=5;WKST=SU;BYMONTH=2,3;BYHOUR=5',
|
||||
id='every-5-hours-at-5-am-in-feb-or-march-with-week-starting-on-sundays',
|
||||
@@ -48,6 +52,7 @@ def test_fast_forwarded_rrule_matches_original_occurrence(rrulestr):
|
||||
[
|
||||
pytest.param(datetime.datetime(2024, 12, 1, 0, 0, tzinfo=datetime.timezone.utc), id='ref-dt-out-of-dst'),
|
||||
pytest.param(datetime.datetime(2024, 6, 1, 0, 0, tzinfo=datetime.timezone.utc), id='ref-dt-in-dst'),
|
||||
pytest.param(datetime.datetime(2024, 11, 3, 6, 30, tzinfo=datetime.timezone.utc), id='ref-dt-fall-back-day'),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
@@ -58,6 +63,8 @@ def test_fast_forwarded_rrule_matches_original_occurrence(rrulestr):
|
||||
pytest.param(
|
||||
'DTSTART;TZID=Europe/Lisbon:20230703T005800 RRULE:INTERVAL=10;FREQ=MINUTELY;BYHOUR=9,10,11,12,13,14,15,16,17,18,19,20,21', id='rrule-in-dst-by-hour'
|
||||
),
|
||||
pytest.param('DTSTART;TZID=America/New_York:20230313T005800 RRULE:FREQ=MINUTELY;INTERVAL=7', id='rrule-post-dst-7min'),
|
||||
pytest.param('DTSTART;TZID=America/New_York:20230313T005800 RRULE:FREQ=MINUTELY;INTERVAL=13', id='rrule-post-dst-13min'),
|
||||
],
|
||||
)
|
||||
def test_fast_forward_across_dst(rrulestr, ref_dt):
|
||||
|
||||
349
awx/main/utils/candlepin/__init__.py
Normal file
349
awx/main/utils/candlepin/__init__.py
Normal file
@@ -0,0 +1,349 @@
|
||||
# Copyright (c) 2026 Ansible, Inc.
|
||||
# All Rights Reserved.
|
||||
|
||||
"""
|
||||
Candlepin integration for mTLS-based authentication.
|
||||
|
||||
This package provides Candlepin consumer identity certificate support,
|
||||
enabling AAP controller instances to authenticate analytics uploads using
|
||||
mTLS instead of service account credentials.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import requests
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from .client import CandlepinClient
|
||||
from .lifecycle import (
|
||||
get_candlepin_ca,
|
||||
get_candlepin_url,
|
||||
get_proxy_url,
|
||||
get_renewal_days,
|
||||
is_cert_valid,
|
||||
parse_cert,
|
||||
run_candlepin_lifecycle,
|
||||
)
|
||||
|
||||
logger = logging.getLogger('awx.main.utils.candlepin')
|
||||
|
||||
|
||||
def _fetch_candlepin_cert_from_db():
|
||||
"""Read cert PEM, key PEM, and consumer UUID from AWX conf_settings.
|
||||
|
||||
Returns (cert_pem, key_pem, consumer_uuid) if valid certificate data exists,
|
||||
or (None, None, None) if placeholder/unregistered data.
|
||||
Best-effort: failures are logged as warnings and never propagate.
|
||||
"""
|
||||
try:
|
||||
consumer_uuid = getattr(settings, 'CANDLEPIN_CONSUMER_UUID', '')
|
||||
cert_pem = getattr(settings, 'CANDLEPIN_CERT_PEM', '')
|
||||
key_pem = getattr(settings, 'CANDLEPIN_KEY_PEM', '')
|
||||
|
||||
# Check if we have valid data
|
||||
if not consumer_uuid or not cert_pem or not key_pem:
|
||||
return None, None, None
|
||||
|
||||
return cert_pem, key_pem, consumer_uuid
|
||||
except Exception as e:
|
||||
logger.warning(f'Could not fetch Candlepin lifecycle data from settings: {e}')
|
||||
return None, None, None
|
||||
|
||||
|
||||
def _save_candlepin_cert_to_db(cert_pem, key_pem):
|
||||
"""Persist a renewed Candlepin identity cert and key to AWX conf_settings.
|
||||
|
||||
Returns:
|
||||
bool: True if save succeeded, False on any error.
|
||||
"""
|
||||
try:
|
||||
# Parse certificate to extract metadata
|
||||
try:
|
||||
cert_info = parse_cert(cert_pem)
|
||||
serial_number = cert_info.get('serial', '')
|
||||
except Exception as e:
|
||||
logger.warning(f'Could not parse certificate metadata: {e}')
|
||||
serial_number = ''
|
||||
|
||||
# Update conf_settings via settings wrapper
|
||||
settings.CANDLEPIN_CERT_PEM = cert_pem
|
||||
settings.CANDLEPIN_KEY_PEM = key_pem
|
||||
settings.CANDLEPIN_SERIAL_NUMBER = serial_number
|
||||
|
||||
logger.info('Renewed Candlepin cert and key saved to conf_settings.')
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f'Could not save renewed Candlepin cert to conf_settings: {e}')
|
||||
return False
|
||||
|
||||
|
||||
def _discover_org(candlepin_url, username, password, verify_tls=True):
|
||||
"""Discover org key via GET /users/{username}/owners.
|
||||
|
||||
Args:
|
||||
candlepin_url: Candlepin base URL
|
||||
username: Username for authentication
|
||||
password: Password for authentication
|
||||
verify_tls: Whether to verify TLS certificates (default: True)
|
||||
|
||||
Returns:
|
||||
str: Organization key if found, None on any failure.
|
||||
"""
|
||||
try:
|
||||
url = f"{candlepin_url}/users/{username}/owners"
|
||||
if verify_tls:
|
||||
candlepin_ca = get_candlepin_ca()
|
||||
verify = candlepin_ca if candlepin_ca else True
|
||||
else:
|
||||
verify = False
|
||||
|
||||
resp = requests.get(url, auth=(username, password), verify=verify, timeout=30)
|
||||
resp.raise_for_status()
|
||||
|
||||
owners = resp.json()
|
||||
if not owners:
|
||||
logger.warning(f'No organizations found for user {username}')
|
||||
return None
|
||||
|
||||
# Pick the first org, but warn if multiple exist
|
||||
if len(owners) > 1:
|
||||
logger.warning(f'User {username} has access to {len(owners)} organizations. Using first: {owners[0]}')
|
||||
first_org = owners[0]
|
||||
org = first_org.get('key')
|
||||
if not org:
|
||||
logger.warning(f'Organization key missing in first org entry for user {username}')
|
||||
return None
|
||||
|
||||
return org
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(f'Failed to discover organization for user {username}: {e}')
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f'Unexpected error discovering organization for user {username}: {e}')
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_registration_credentials_from_db(verify_tls=True):
|
||||
"""Read Candlepin registration credentials from AWX settings.
|
||||
|
||||
Tries several options to retrieve the Candlepin credentials (set by AWX when the
|
||||
customer configures their Red Hat subscription), and to discover the org (org
|
||||
key for the Candlepin /consumers endpoint), and INSTALL_UUID (used as the
|
||||
consumer's aap.instance_uuid fact).
|
||||
|
||||
Priority for authentication credentials:
|
||||
- If both REDHAT_USERNAME and SUBSCRIPTIONS_USERNAME exist: use REDHAT_USERNAME
|
||||
- If only SUBSCRIPTIONS_USERNAME exists: use SUBSCRIPTIONS_USERNAME
|
||||
|
||||
Args:
|
||||
verify_tls: Whether to verify TLS certificates during org discovery (default: True)
|
||||
|
||||
Returns (username, password, org, install_uuid), any of which may be None
|
||||
if the corresponding setting is not configured.
|
||||
"""
|
||||
candlepin_url = get_candlepin_url()
|
||||
try:
|
||||
username = getattr(settings, 'REDHAT_USERNAME', None)
|
||||
password = getattr(settings, 'REDHAT_PASSWORD', None)
|
||||
|
||||
if not (username and password):
|
||||
username = getattr(settings, 'SUBSCRIPTIONS_USERNAME', None)
|
||||
password = getattr(settings, 'SUBSCRIPTIONS_PASSWORD', None)
|
||||
|
||||
install_uuid = getattr(settings, 'INSTALL_UUID', None)
|
||||
|
||||
org = _discover_org(candlepin_url, username, password, verify_tls=verify_tls) if username and password else None
|
||||
|
||||
return username, password, org, install_uuid
|
||||
except Exception as e:
|
||||
logger.warning(f'Could not fetch Candlepin registration credentials from settings: {e}')
|
||||
return None, None, None, None
|
||||
|
||||
|
||||
def resolve_registration_credentials(username_override=None, password_override=None, org_override=None, verify_tls=True):
|
||||
"""Resolve Candlepin registration credentials with optional overrides.
|
||||
|
||||
Fetches credentials from database settings and merges with any provided overrides.
|
||||
Validates that all required fields are present.
|
||||
|
||||
Args:
|
||||
username_override: Optional username to use instead of database value
|
||||
password_override: Optional password to use instead of database value
|
||||
org_override: Optional org to use instead of auto-discovered value
|
||||
verify_tls: Whether to verify TLS certificates during org discovery (default: True)
|
||||
|
||||
Returns:
|
||||
Tuple (username, password, org, install_uuid) if all required fields present,
|
||||
or (None, None, None, None, error_messages) if validation fails.
|
||||
error_messages is a list of strings describing missing values.
|
||||
"""
|
||||
db_username, db_password, db_org, db_install_uuid = _fetch_registration_credentials_from_db(verify_tls=verify_tls)
|
||||
|
||||
username = username_override or db_username
|
||||
password = password_override or db_password
|
||||
org = org_override or db_org
|
||||
|
||||
# Validate all required fields are present
|
||||
missing = []
|
||||
if not username:
|
||||
missing.append('username (provide --username or set REDHAT_USERNAME in database)')
|
||||
if not password:
|
||||
missing.append('password (provide password or set REDHAT_PASSWORD in database)')
|
||||
if not org:
|
||||
missing.append('org (provide --org or ensure SUBSCRIPTIONS_USERNAME/PASSWORD are configured for auto-discovery)')
|
||||
|
||||
if missing:
|
||||
return None, None, None, None, missing
|
||||
|
||||
return username, password, org, db_install_uuid, None
|
||||
|
||||
|
||||
def _save_candlepin_registration_to_db(cert_pem, key_pem, consumer_uuid):
|
||||
"""Persist a new Candlepin consumer registration (cert, key, UUID) to AWX conf_settings.
|
||||
|
||||
Returns:
|
||||
bool: True if save succeeded, False on any error.
|
||||
"""
|
||||
try:
|
||||
# Parse certificate to extract metadata
|
||||
try:
|
||||
cert_info = parse_cert(cert_pem)
|
||||
serial_number = cert_info.get('serial', '')
|
||||
except Exception as e:
|
||||
logger.warning(f'Could not parse certificate metadata: {e}')
|
||||
serial_number = ''
|
||||
|
||||
# Update conf_settings with all registration data via settings wrapper
|
||||
settings.CANDLEPIN_CONSUMER_UUID = consumer_uuid
|
||||
settings.CANDLEPIN_CERT_PEM = cert_pem
|
||||
settings.CANDLEPIN_KEY_PEM = key_pem
|
||||
settings.CANDLEPIN_SERIAL_NUMBER = serial_number
|
||||
|
||||
logger.info(f'Candlepin consumer registration saved to conf_settings (uuid={consumer_uuid}).')
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f'Could not save Candlepin registration to conf_settings: {e}')
|
||||
return False
|
||||
|
||||
|
||||
def _register_candlepin_consumer():
|
||||
"""Register a new Candlepin consumer using credentials from AWX settings.
|
||||
|
||||
Called when no identity cert exists in the DB.
|
||||
|
||||
Reads the Candlepin credentials and the org key and then calls
|
||||
POST /consumers on Candlepin to obtain an identity certificate.
|
||||
On success the cert, key, and consumer UUID are persisted to conf_settings.
|
||||
|
||||
Returns (cert_pem, key_pem, consumer_uuid) on success, (None, None, None) on
|
||||
any failure. Best-effort: logs errors but never propagates.
|
||||
"""
|
||||
username, password, org, install_uuid = _fetch_registration_credentials_from_db()
|
||||
|
||||
if not username or not password:
|
||||
logger.warning('Candlepin registration is enabled but credentials are not set; skipping registration.')
|
||||
return None, None, None
|
||||
|
||||
if not org:
|
||||
logger.warning('Candlepin registration is enabled but subscription org is not available; skipping registration.')
|
||||
return None, None, None
|
||||
|
||||
candlepin_url = get_candlepin_url()
|
||||
candlepin_ca = get_candlepin_ca()
|
||||
proxy = get_proxy_url()
|
||||
client = CandlepinClient(base_url=candlepin_url, candlepin_ca=candlepin_ca, proxy=proxy)
|
||||
|
||||
try:
|
||||
cert_pem, key_pem, consumer_uuid = client.register_consumer(username, password, org, install_uuid)
|
||||
except Exception as e:
|
||||
logger.error(f'Candlepin consumer registration failed: {e}')
|
||||
return None, None, None
|
||||
|
||||
if not _save_candlepin_registration_to_db(cert_pem, key_pem, consumer_uuid):
|
||||
logger.error('Candlepin consumer registration succeeded but failed to save to database.')
|
||||
return None, None, None
|
||||
return cert_pem, key_pem, consumer_uuid
|
||||
|
||||
|
||||
def _run_candlepin_lifecycle(cert_pem, key_pem, consumer_uuid):
|
||||
"""Orchestrate Candlepin check-in and proactive cert renewal.
|
||||
|
||||
Returns the (possibly renewed) (cert_pem, key_pem) tuple. If renewal fails, the
|
||||
original cert is returned and the caller will validate it with is_cert_valid().
|
||||
If invalid, the caller skips mTLS and falls back directly to OIDC authentication.
|
||||
"""
|
||||
if not consumer_uuid:
|
||||
logger.warning('Candlepin lifecycle is enabled but consumer UUID is not set; skipping check-in and renewal.')
|
||||
return cert_pem, key_pem
|
||||
|
||||
candlepin_url = get_candlepin_url()
|
||||
renewal_days = get_renewal_days()
|
||||
candlepin_ca = get_candlepin_ca()
|
||||
proxy = get_proxy_url()
|
||||
|
||||
try:
|
||||
new_cert_pem, new_key_pem = run_candlepin_lifecycle(
|
||||
cert_pem,
|
||||
key_pem,
|
||||
consumer_uuid,
|
||||
candlepin_url=candlepin_url,
|
||||
renewal_days=renewal_days,
|
||||
candlepin_ca=candlepin_ca,
|
||||
proxy=proxy,
|
||||
)
|
||||
if (new_cert_pem, new_key_pem) != (cert_pem, key_pem):
|
||||
if not _save_candlepin_cert_to_db(new_cert_pem, new_key_pem):
|
||||
logger.warning('Renewed certificate will be used for this request, but failed to persist to database for future use.')
|
||||
return new_cert_pem, new_key_pem
|
||||
except Exception as e:
|
||||
logger.error(f'Candlepin lifecycle (check-in / renewal) failed: {e}; will attempt mTLS with existing cert')
|
||||
return cert_pem, key_pem
|
||||
|
||||
|
||||
def get_or_generate_candlepin_certificate():
|
||||
"""
|
||||
Get or generate Candlepin certificate for analytics authentication.
|
||||
|
||||
This function provides certificate-based authentication for analytics uploads.
|
||||
It will:
|
||||
1. Check for existing certificate in conf_settings
|
||||
2. If missing, attempt to register with Candlepin (credentials from settings)
|
||||
3. If exists, check for renewal needs and refresh if needed
|
||||
4. Return the certificate and key as PEM strings
|
||||
|
||||
Returns:
|
||||
Tuple (cert_pem, key_pem) as strings if certificate is available, (None, None) otherwise.
|
||||
|
||||
Note:
|
||||
Credentials for registration are retrieved from Django settings internally
|
||||
(REDHAT_USERNAME/PASSWORD, SUBSCRIPTIONS_USERNAME/PASSWORD, or
|
||||
SUBSCRIPTIONS_CLIENT_ID/CLIENT_SECRET in priority order).
|
||||
"""
|
||||
cert_pem, key_pem, consumer_uuid = _fetch_candlepin_cert_from_db()
|
||||
|
||||
# If no certificate exists, attempt registration
|
||||
if not cert_pem or not key_pem:
|
||||
logger.info('No Candlepin certificate found, attempting registration')
|
||||
cert_pem, key_pem, consumer_uuid = _register_candlepin_consumer()
|
||||
|
||||
if not cert_pem or not key_pem:
|
||||
logger.debug('Candlepin certificate registration failed or not configured')
|
||||
return None, None
|
||||
|
||||
# Run lifecycle (check-in and renewal if needed)
|
||||
if consumer_uuid:
|
||||
cert_pem, key_pem = _run_candlepin_lifecycle(cert_pem, key_pem, consumer_uuid)
|
||||
|
||||
# Validate certificate is still usable
|
||||
if not is_cert_valid(cert_pem):
|
||||
logger.warning('Candlepin certificate is not valid (expired or not yet valid)')
|
||||
return None, None
|
||||
|
||||
# Return raw PEM strings - caller will create temp files if needed
|
||||
return cert_pem, key_pem
|
||||
|
||||
|
||||
__all__ = [
|
||||
'get_or_generate_candlepin_certificate',
|
||||
'resolve_registration_credentials',
|
||||
]
|
||||
258
awx/main/utils/candlepin/client.py
Normal file
258
awx/main/utils/candlepin/client.py
Normal file
@@ -0,0 +1,258 @@
|
||||
import os
|
||||
import tempfile
|
||||
import uuid as _uuid_mod
|
||||
from datetime import datetime, timezone
|
||||
import requests
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger('awx.main.utils.candlepin')
|
||||
|
||||
|
||||
class _temp_cert_files:
|
||||
"""
|
||||
Context manager: writes cert + key to secure temp files, auto-deletes on exit.
|
||||
|
||||
Uses NamedTemporaryFile with delete=True for better cleanup on process termination.
|
||||
Files are unlinked immediately on Unix systems, providing better security against
|
||||
orphaned private keys in /tmp.
|
||||
"""
|
||||
|
||||
def __init__(self, cert_pem, key_pem):
|
||||
self._cert_pem = cert_pem
|
||||
self._key_pem = key_pem
|
||||
self._cert_file = None
|
||||
self._key_file = None
|
||||
|
||||
def __enter__(self):
|
||||
try:
|
||||
# Create temp file for certificate
|
||||
self._cert_file = tempfile.NamedTemporaryFile(mode='w', prefix='candlepin_cert_', suffix='.pem', delete=True)
|
||||
self._cert_file.write(self._cert_pem)
|
||||
self._cert_file.flush()
|
||||
os.chmod(self._cert_file.name, 0o600)
|
||||
|
||||
# Create temp file for private key
|
||||
self._key_file = tempfile.NamedTemporaryFile(mode='w', prefix='candlepin_key_', suffix='.pem', delete=True)
|
||||
self._key_file.write(self._key_pem)
|
||||
self._key_file.flush()
|
||||
os.chmod(self._key_file.name, 0o600)
|
||||
|
||||
return self._cert_file.name, self._key_file.name
|
||||
except Exception:
|
||||
# Clean up on error
|
||||
if self._cert_file:
|
||||
self._cert_file.close()
|
||||
if self._key_file:
|
||||
self._key_file.close()
|
||||
raise
|
||||
|
||||
def __exit__(self, *_):
|
||||
# Closing NamedTemporaryFile automatically deletes it
|
||||
if self._cert_file:
|
||||
try:
|
||||
self._cert_file.close()
|
||||
except Exception as e:
|
||||
logger.warning(f'Error closing cert temp file: {e}')
|
||||
if self._key_file:
|
||||
try:
|
||||
self._key_file.close()
|
||||
except Exception as e:
|
||||
logger.warning(f'Error closing key temp file: {e}')
|
||||
|
||||
|
||||
class CandlepinClient:
|
||||
"""
|
||||
Minimal Candlepin REST client for certificate lifecycle operations.
|
||||
|
||||
All API calls authenticate with the consumer identity certificate (mTLS),
|
||||
matching the pattern used by subscription-manager after initial registration.
|
||||
|
||||
TLS server verification is **enabled** by default (``verify_tls=True``).
|
||||
Pass ``candlepin_ca`` to verify against a specific CA bundle rather than the
|
||||
system trust store. Verification can only be disabled by explicitly passing
|
||||
``verify_tls=False``; this should be used only in controlled test environments
|
||||
and never in production.
|
||||
"""
|
||||
|
||||
def __init__(self, base_url, candlepin_ca=None, proxy=None, verify_tls=True):
|
||||
self.base_url = base_url.rstrip('/')
|
||||
if candlepin_ca:
|
||||
self.verify = candlepin_ca
|
||||
elif verify_tls:
|
||||
self.verify = True
|
||||
else:
|
||||
# Explicit opt-in required to reach this branch — never set by default.
|
||||
logger.warning('CandlepinClient: TLS verification is DISABLED (verify_tls=False). Do not use in production.')
|
||||
self.verify = False
|
||||
if proxy:
|
||||
# Use the caller-supplied URL as-is for HTTPS targets (preserves the
|
||||
# intended scheme — usually http:// so requests uses plain HTTP to reach
|
||||
# the proxy and issues CONNECT for TLS tunneling, but https:// is also
|
||||
# accepted for the rare case of an HTTPS-fronted proxy).
|
||||
# The http:// key always uses plain HTTP since non-TLS traffic never
|
||||
# needs TLS to the proxy itself.
|
||||
host = proxy.split('://', 1)[-1]
|
||||
self.proxies = {'https': proxy, 'http': f'http://{host}'}
|
||||
else:
|
||||
self.proxies = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def register_consumer(self, username, password, org, install_uuid=None):
|
||||
"""POST /consumers?owner={org} — register a new AAP consumer with basic auth.
|
||||
|
||||
Uses the customer's Red Hat subscription credentials (REDHAT_USERNAME /
|
||||
REDHAT_PASSWORD from AWX conf_setting) to register this controller
|
||||
instance as a Candlepin consumer and obtain an identity certificate for mTLS.
|
||||
|
||||
Args:
|
||||
username: Red Hat subscription username (from REDHAT_USERNAME).
|
||||
password: Red Hat subscription password (from REDHAT_PASSWORD).
|
||||
org: Candlepin owner/org key (retrieved with subscription credentials).
|
||||
install_uuid: AWX INSTALL_UUID used as the consumer's aap.instance_uuid
|
||||
fact; falls back to a random UUID if not provided.
|
||||
|
||||
Returns:
|
||||
Tuple ``(cert_pem, key_pem, consumer_uuid)``.
|
||||
|
||||
Raises:
|
||||
RuntimeError on any network or API failure.
|
||||
"""
|
||||
url = f'{self.base_url}/consumers'
|
||||
instance_uuid = install_uuid or str(_uuid_mod.uuid4())
|
||||
payload = {
|
||||
'name': f'aap-{instance_uuid[:8]}',
|
||||
'type': {'label': 'aap'},
|
||||
'facts': {
|
||||
'system.certificate_version': '3.3',
|
||||
'system.name': 'aap-controller',
|
||||
'aap.instance_uuid': instance_uuid,
|
||||
},
|
||||
}
|
||||
try:
|
||||
resp = requests.post(
|
||||
url,
|
||||
params={'owner': org},
|
||||
auth=(username, password),
|
||||
json=payload,
|
||||
headers={'Content-Type': 'application/json'},
|
||||
verify=self.verify,
|
||||
proxies=self.proxies,
|
||||
timeout=120,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'Candlepin register_consumer network error: {e}') from e
|
||||
|
||||
if not resp.ok:
|
||||
raise RuntimeError(f'Candlepin register_consumer failed with status {resp.status_code}: {resp.text}')
|
||||
|
||||
try:
|
||||
body = resp.json()
|
||||
consumer_uuid = body.get('uuid')
|
||||
id_cert = body.get('idCert', {})
|
||||
cert_pem = id_cert.get('cert')
|
||||
key_pem = id_cert.get('key')
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'Candlepin register_consumer: could not parse response JSON: {e}') from e
|
||||
|
||||
if not consumer_uuid or not cert_pem or not key_pem:
|
||||
raise RuntimeError('Candlepin register_consumer: response missing uuid, idCert.cert or idCert.key')
|
||||
|
||||
logger.info(f'Candlepin consumer registered successfully (uuid={consumer_uuid})')
|
||||
return cert_pem, key_pem, consumer_uuid
|
||||
|
||||
def get_consumer(self, consumer_uuid, cert_pem, key_pem):
|
||||
"""GET /consumers/{uuid} — retrieve consumer information from server.
|
||||
|
||||
Best-effort: logs a warning on failure but never raises.
|
||||
|
||||
Returns:
|
||||
Dict with consumer data (including 'idCert' with serial) on success,
|
||||
None on any failure.
|
||||
"""
|
||||
url = f'{self.base_url}/consumers/{consumer_uuid}'
|
||||
try:
|
||||
with _temp_cert_files(cert_pem, key_pem) as (cert_path, key_path):
|
||||
resp = requests.get(
|
||||
url,
|
||||
cert=(cert_path, key_path),
|
||||
verify=self.verify,
|
||||
proxies=self.proxies,
|
||||
timeout=30,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
logger.debug(f'Candlepin get_consumer successful for consumer {consumer_uuid}')
|
||||
return resp.json()
|
||||
logger.warning(f'Candlepin get_consumer returned unexpected status {resp.status_code} for consumer {consumer_uuid}')
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f'Candlepin get_consumer failed for consumer {consumer_uuid}: {e}')
|
||||
return None
|
||||
|
||||
def checkin(self, consumer_uuid, cert_pem, key_pem):
|
||||
"""PUT /consumers/{uuid} — reset inactivity timer.
|
||||
|
||||
Best-effort: logs a warning on failure but never raises so that a
|
||||
transient Candlepin outage cannot abort a gather run.
|
||||
|
||||
Returns True on success, False on any failure.
|
||||
"""
|
||||
url = f'{self.base_url}/consumers/{consumer_uuid}'
|
||||
try:
|
||||
with _temp_cert_files(cert_pem, key_pem) as (cert_path, key_path):
|
||||
resp = requests.put(
|
||||
url,
|
||||
cert=(cert_path, key_path),
|
||||
json={'facts': {'aap.last_checkin': datetime.now(timezone.utc).isoformat()}},
|
||||
headers={'Content-Type': 'application/json'},
|
||||
verify=self.verify,
|
||||
proxies=self.proxies,
|
||||
timeout=30,
|
||||
)
|
||||
if resp.status_code in (200, 204):
|
||||
logger.info(f'Candlepin check-in successful for consumer {consumer_uuid}')
|
||||
return True
|
||||
logger.warning(f'Candlepin check-in returned unexpected status {resp.status_code} for consumer {consumer_uuid}')
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f'Candlepin check-in failed for consumer {consumer_uuid}: {e}')
|
||||
return False
|
||||
|
||||
def regenerate_cert(self, consumer_uuid, cert_pem, key_pem):
|
||||
"""POST /consumers/{uuid} — regenerate the identity certificate.
|
||||
|
||||
Returns ``(new_cert_pem, new_key_pem)`` on success.
|
||||
Raises ``RuntimeError`` on API or parsing failure so the caller can
|
||||
decide whether to fall back to service-account auth.
|
||||
"""
|
||||
url = f'{self.base_url}/consumers/{consumer_uuid}'
|
||||
with _temp_cert_files(cert_pem, key_pem) as (cert_path, key_path):
|
||||
try:
|
||||
resp = requests.post(
|
||||
url,
|
||||
cert=(cert_path, key_path),
|
||||
verify=self.verify,
|
||||
proxies=self.proxies,
|
||||
timeout=120,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'Candlepin regenerate_cert network error for consumer {consumer_uuid}: {e}') from e
|
||||
|
||||
if not resp.ok:
|
||||
raise RuntimeError(f'Candlepin regenerate_cert failed with status {resp.status_code} for consumer {consumer_uuid}: {resp.text}')
|
||||
|
||||
try:
|
||||
body = resp.json()
|
||||
id_cert = body.get('idCert', {})
|
||||
new_cert_pem = id_cert.get('cert')
|
||||
new_key_pem = id_cert.get('key')
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'Candlepin regenerate_cert: could not parse response JSON: {e}') from e
|
||||
|
||||
if not new_cert_pem or not new_key_pem:
|
||||
raise RuntimeError(f'Candlepin regenerate_cert: response did not contain idCert.cert / idCert.key for consumer {consumer_uuid}')
|
||||
|
||||
logger.info(f'Candlepin cert regenerated successfully for consumer {consumer_uuid}')
|
||||
return new_cert_pem, new_key_pem
|
||||
221
awx/main/utils/candlepin/lifecycle.py
Normal file
221
awx/main/utils/candlepin/lifecycle.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""
|
||||
Candlepin certificate lifecycle helpers.
|
||||
|
||||
is_cert_valid — quick parseable/non-expired guard used at ship time
|
||||
parse_cert — extract metadata from a PEM cert string
|
||||
needs_renewal — check whether a cert is within the renewal window
|
||||
run_candlepin_lifecycle — orchestrate check-in + proactive renewal per gather run
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from cryptography import x509
|
||||
from django.conf import settings
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger('awx.main.utils.candlepin')
|
||||
|
||||
from .client import CandlepinClient
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Certificate helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_cert(pem_text):
|
||||
"""Parse a PEM certificate and return a metadata dict.
|
||||
|
||||
Returns a dict with keys: serial, cn, issuer_cn, issuer_org,
|
||||
not_before, not_after, days_remaining, validity_days.
|
||||
|
||||
Raises ``ValueError`` if the PEM cannot be parsed.
|
||||
"""
|
||||
data = pem_text.encode('utf-8') if isinstance(pem_text, str) else pem_text
|
||||
try:
|
||||
cert = x509.load_pem_x509_certificate(data)
|
||||
except Exception as e:
|
||||
raise ValueError(f'Could not parse PEM certificate: {e}') from e
|
||||
|
||||
expiry = cert.not_valid_after_utc
|
||||
remaining = expiry - datetime.now(timezone.utc)
|
||||
|
||||
subject = {attr.oid._name: attr.value for attr in cert.subject}
|
||||
issuer = {attr.oid._name: attr.value for attr in cert.issuer}
|
||||
|
||||
return {
|
||||
'serial': str(cert.serial_number),
|
||||
'cn': subject.get('commonName', 'unknown'),
|
||||
'issuer_cn': issuer.get('commonName', 'unknown'),
|
||||
'issuer_org': issuer.get('organizationName', 'unknown'),
|
||||
'not_before': cert.not_valid_before_utc.isoformat(),
|
||||
'not_after': expiry.isoformat(),
|
||||
'days_remaining': remaining.days,
|
||||
'validity_days': (expiry - cert.not_valid_before_utc).days,
|
||||
}
|
||||
|
||||
|
||||
def is_cert_valid(cert_pem: str) -> bool:
|
||||
"""Return True if cert_pem is parseable, already valid, and not yet expired.
|
||||
|
||||
Logs a warning (suitable for operator visibility) when the cert is not yet
|
||||
valid, expired, or unparseable, then returns False so the caller can fall
|
||||
back to service-account authentication.
|
||||
"""
|
||||
try:
|
||||
info = parse_cert(cert_pem)
|
||||
now = datetime.now(timezone.utc)
|
||||
not_before = datetime.fromisoformat(info['not_before'])
|
||||
if now < not_before:
|
||||
logger.warning(f'Candlepin cert is not yet valid (not_before={info["not_before"]}); falling back to service account auth')
|
||||
return False
|
||||
if info['days_remaining'] < 0:
|
||||
logger.warning(f'Candlepin cert expired at {info["not_after"]}; falling back to service account auth')
|
||||
return False
|
||||
return True
|
||||
except ValueError as e:
|
||||
logger.warning(f'Could not parse Candlepin cert: {e}')
|
||||
return False
|
||||
|
||||
|
||||
def needs_renewal(pem_text, days_before_expiry):
|
||||
"""Return True if the cert expires within ``days_before_expiry`` days.
|
||||
|
||||
Also returns True if the cert is already expired (days_remaining < 0).
|
||||
Raises ``ValueError`` if the PEM cannot be parsed.
|
||||
"""
|
||||
info = parse_cert(pem_text)
|
||||
return info['days_remaining'] <= days_before_expiry
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lifecycle orchestration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def run_candlepin_lifecycle(cert_pem, key_pem, consumer_uuid, *, candlepin_url=None, renewal_days=90, candlepin_ca=None, proxy=None):
|
||||
"""Perform check-in and, if needed, proactive cert renewal.
|
||||
|
||||
Called once per gather run. Returns ``(cert_pem, key_pem)`` — either
|
||||
the originals (if no renewal was needed) or the freshly regenerated pair.
|
||||
|
||||
Args:
|
||||
cert_pem: Consumer identity certificate PEM string.
|
||||
key_pem: Consumer identity key PEM string.
|
||||
consumer_uuid: Candlepin consumer UUID string.
|
||||
candlepin_url: Candlepin base URL (defaults to prod).
|
||||
renewal_days: Renew if expiry is within this many days (default 90).
|
||||
candlepin_ca: Path to Candlepin CA cert for server verification
|
||||
(default None → uses system trust store).
|
||||
proxy: Optional HTTP/HTTPS proxy URL string.
|
||||
|
||||
Returns:
|
||||
Tuple ``(cert_pem, key_pem)`` — possibly updated after renewal.
|
||||
|
||||
Raises:
|
||||
RuntimeError if cert regeneration is attempted and fails.
|
||||
"""
|
||||
client = CandlepinClient(base_url=candlepin_url, candlepin_ca=candlepin_ca, proxy=proxy)
|
||||
|
||||
# Step 1: Inspect cert metadata for diagnostics and renewal decision.
|
||||
try:
|
||||
info = parse_cert(cert_pem)
|
||||
except ValueError as e:
|
||||
logger.warning(f'Candlepin lifecycle: could not parse cert, skipping lifecycle: {e}')
|
||||
return cert_pem, key_pem
|
||||
|
||||
logger.info(f'Candlepin cert: serial={info["serial"]}, CN={info["cn"]}, expires={info["not_after"]}, days_remaining={info["days_remaining"]}')
|
||||
|
||||
# Step 2: Check-in (best-effort, never raises).
|
||||
checkin_success = client.checkin(consumer_uuid, cert_pem, key_pem)
|
||||
if not checkin_success:
|
||||
logger.warning(
|
||||
f'Candlepin check-in failed for consumer {consumer_uuid}. '
|
||||
f'Consumer may have been deleted server-side or certificate is invalid. '
|
||||
f'Lifecycle will continue but may fail.'
|
||||
)
|
||||
|
||||
# Step 3: Compare local cert serial with server's serial.
|
||||
# If they differ, the server has issued a new cert (e.g., admin regenerated it).
|
||||
consumer_data = client.get_consumer(consumer_uuid, cert_pem, key_pem)
|
||||
if not consumer_data:
|
||||
if not checkin_success:
|
||||
logger.error(
|
||||
f'Both check-in and get_consumer failed for consumer {consumer_uuid}. '
|
||||
f'Consumer was likely deleted from Candlepin server. '
|
||||
f'Re-registration may be required. Will attempt cert renewal anyway.'
|
||||
)
|
||||
else:
|
||||
logger.warning(f'Could not retrieve consumer data for {consumer_uuid} but check-in succeeded. Continuing lifecycle.')
|
||||
else:
|
||||
server_cert_pem = consumer_data.get('idCert', {}).get('cert')
|
||||
if server_cert_pem:
|
||||
try:
|
||||
server_info = parse_cert(server_cert_pem)
|
||||
server_serial = server_info['serial']
|
||||
local_serial = info['serial']
|
||||
|
||||
if server_serial != local_serial:
|
||||
logger.warning(
|
||||
f'Candlepin cert serial mismatch: local={local_serial}, server={server_serial}. '
|
||||
f'Server has issued a new certificate; requesting updated cert.'
|
||||
)
|
||||
# Fetch the new cert from the server
|
||||
new_cert_pem, new_key_pem = client.regenerate_cert(consumer_uuid, cert_pem, key_pem)
|
||||
|
||||
try:
|
||||
new_info = parse_cert(new_cert_pem)
|
||||
logger.info(f'Candlepin cert updated: old serial={local_serial}, new serial={new_info["serial"]}, new expiry={new_info["not_after"]}')
|
||||
except ValueError:
|
||||
logger.warning('Candlepin lifecycle: could not parse updated cert for logging')
|
||||
|
||||
return new_cert_pem, new_key_pem
|
||||
else:
|
||||
logger.debug(f'Candlepin cert serial matches server: {local_serial}')
|
||||
except ValueError as e:
|
||||
logger.warning(f'Candlepin lifecycle: could not parse server cert from get_consumer: {e}')
|
||||
|
||||
# Step 4: Proactive renewal if within the renewal window (or already expired).
|
||||
if needs_renewal(cert_pem, renewal_days):
|
||||
logger.info(f'Candlepin cert expires in {info["days_remaining"]} days (threshold: {renewal_days}); requesting renewal for consumer {consumer_uuid}')
|
||||
new_cert_pem, new_key_pem = client.regenerate_cert(consumer_uuid, cert_pem, key_pem)
|
||||
|
||||
try:
|
||||
new_info = parse_cert(new_cert_pem)
|
||||
logger.info(f'Candlepin cert renewed: old serial={info["serial"]}, new serial={new_info["serial"]}, new expiry={new_info["not_after"]}')
|
||||
except ValueError:
|
||||
logger.warning('Candlepin lifecycle: could not parse renewed cert for logging')
|
||||
|
||||
return new_cert_pem, new_key_pem
|
||||
|
||||
logger.info(f'Candlepin cert is healthy ({info["days_remaining"]} days remaining); no renewal needed')
|
||||
return cert_pem, key_pem
|
||||
|
||||
|
||||
def get_candlepin_url():
|
||||
"""Get Candlepin base URL from Django settings."""
|
||||
return settings.AWX_ANALYTICS_CANDLEPIN_URL
|
||||
|
||||
|
||||
def get_renewal_days():
|
||||
"""Get certificate renewal threshold in days from Django settings."""
|
||||
return settings.AWX_ANALYTICS_CANDLEPIN_RENEWAL_THRESHOLD_DAYS
|
||||
|
||||
|
||||
def get_candlepin_ca():
|
||||
"""Get Candlepin CA certificate path from Django settings.
|
||||
|
||||
Returns:
|
||||
str: Path to CA certificate file if configured and exists, None otherwise.
|
||||
"""
|
||||
ca_path = settings.AWX_ANALYTICS_CANDLEPIN_CA
|
||||
if ca_path and not os.path.isfile(ca_path):
|
||||
logger.warning(f'Configured Candlepin CA certificate not found at {ca_path}, using system default CA bundle')
|
||||
return None
|
||||
return ca_path
|
||||
|
||||
|
||||
def get_proxy_url():
|
||||
"""Get proxy URL from Django settings."""
|
||||
return settings.AWX_ANALYTICS_CANDLEPIN_PROXY_URL
|
||||
@@ -93,6 +93,7 @@ __all__ = [
|
||||
'get_event_partition_epoch',
|
||||
'cleanup_new_process',
|
||||
'unified_job_class_to_event_table_name',
|
||||
'get_job_variable_prefixes',
|
||||
]
|
||||
|
||||
|
||||
@@ -150,14 +151,6 @@ def is_testing(argv=None):
|
||||
return False
|
||||
|
||||
|
||||
def bypass_in_test(func):
|
||||
def fn(*args, **kwargs):
|
||||
if not is_testing():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
class RequireDebugTrueOrTest(logging.Filter):
|
||||
"""
|
||||
Logging filter to output when in DEBUG mode or running tests.
|
||||
@@ -773,6 +766,21 @@ def get_cpu_effective_capacity(cpu_count, is_control_node=False):
|
||||
return max(1, int(cpu_count * forkcpu))
|
||||
|
||||
|
||||
def get_job_variable_prefixes():
|
||||
"""Return the list of active job variable prefixes based on INCLUDE_DEPRECATED_AWX_VAR_PREFIX setting.
|
||||
|
||||
When True (default), returns both 'awx' and 'tower' prefixes for backward compatibility.
|
||||
When False, returns only 'tower'. The 'awx' prefix is deprecated and this setting
|
||||
will default to False in a future release.
|
||||
"""
|
||||
from django.conf import settings
|
||||
|
||||
include_awx = getattr(settings, 'INCLUDE_DEPRECATED_AWX_VAR_PREFIX', True)
|
||||
if include_awx:
|
||||
return ['awx', 'tower']
|
||||
return ['tower']
|
||||
|
||||
|
||||
def convert_mem_str_to_bytes(mem_str):
|
||||
"""Convert string with suffix indicating units to memory in bytes (base 2)
|
||||
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
# Copyright (c) 2017 Ansible by Red Hat
|
||||
# All Rights Reserved.
|
||||
|
||||
from typing import Optional
|
||||
import os
|
||||
|
||||
from awx.settings.application_name import set_application_name
|
||||
from awx import MODE
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import connection
|
||||
|
||||
|
||||
def set_connection_name(function):
|
||||
@@ -32,3 +37,25 @@ def bulk_update_sorted_by_id(model, objects, fields, batch_size=1000):
|
||||
|
||||
sorted_objects = sorted(objects, key=lambda obj: obj.id)
|
||||
return model.objects.bulk_update(sorted_objects, fields, batch_size=batch_size)
|
||||
|
||||
|
||||
MIN_PG_VERSION = 12
|
||||
|
||||
|
||||
def db_requirement_violations() -> Optional[str]:
|
||||
if os.getenv('SKIP_PG_VERSION_CHECK', False):
|
||||
return None
|
||||
if connection.vendor == 'postgresql':
|
||||
|
||||
# enforce the postgres version is a minimum of 12 (we need this for partitioning); if not, then terminate program with exit code of 1
|
||||
# In the future if we require a feature of a version of postgres > 12 this should be updated to reflect that.
|
||||
# The return of connection.pg_version is something like 12013
|
||||
major_version = connection.pg_version // 10000
|
||||
if major_version < MIN_PG_VERSION:
|
||||
return f"At a minimum, postgres version {MIN_PG_VERSION} is required, found {major_version}\n"
|
||||
|
||||
return None
|
||||
else:
|
||||
if MODE == 'production':
|
||||
return f"Running server with '{connection.vendor}' type database is not supported\n"
|
||||
return None
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user