diff --git a/awx/api/views.py b/awx/api/views.py index a7bc27a927..cc8288a1f7 100644 --- a/awx/api/views.py +++ b/awx/api/views.py @@ -2443,7 +2443,7 @@ class InventoryScriptView(RetrieveAPIView): for host in obj.hosts.filter(**hosts_q): data['_meta']['hostvars'][host.name] = host.variables_dict if towervars: - tower_dict = dict(remote_tower_enabled=host.enabled, + tower_dict = dict(remote_tower_enabled=str(host.enabled).lower(), remote_tower_id=host.id) data['_meta']['hostvars'][host.name].update(tower_dict) diff --git a/awx/main/managers.py b/awx/main/managers.py index 4825b33231..69d4a24b6b 100644 --- a/awx/main/managers.py +++ b/awx/main/managers.py @@ -7,7 +7,7 @@ import logging from django.db import models from django.utils.timezone import now -from django.db.models import Sum +from django.db.models import Sum, Q from django.conf import settings from awx.main.utils.filters import SmartFilter @@ -21,9 +21,9 @@ class HostManager(models.Manager): """Custom manager class for Hosts model.""" def active_count(self): - """Return count of active, unique hosts for licensing.""" + """Return count of active, unique hosts for licensing. Exclude ones source from another Tower""" try: - return self.order_by('name').distinct('name').count() + return self.filter(~Q(inventory_sources__source='tower')).order_by('name').distinct('name').count() except NotImplementedError: # For unit tests only, SQLite doesn't support distinct('name') return len(set(self.values_list('name', flat=True))) diff --git a/awx/main/tasks.py b/awx/main/tasks.py index c3adb98a07..e284aee63d 100644 --- a/awx/main/tasks.py +++ b/awx/main/tasks.py @@ -1914,6 +1914,7 @@ class RunInventoryUpdate(BaseTask): env[str(env_k)] = unicode(inventory_update.source_vars_dict[env_k]) elif inventory_update.source == 'tower': env['TOWER_INVENTORY'] = inventory_update.instance_filters + env['TOWER_LICENSE_TYPE'] = get_licenser().validate()['license_type'] elif inventory_update.source == 'file': raise NotImplementedError('Cannot update file sources through the task system.') # add private_data_files diff --git a/awx/main/tests/functional/test_credential.py b/awx/main/tests/functional/test_credential.py index 9bcf23e198..d69ee9ce37 100644 --- a/awx/main/tests/functional/test_credential.py +++ b/awx/main/tests/functional/test_credential.py @@ -29,6 +29,7 @@ def test_default_cred_types(): 'satellite6', 'scm', 'ssh', + 'tower', 'vault', 'vmware', ] diff --git a/awx/plugins/inventory/tower.py b/awx/plugins/inventory/tower.py index 0fae0865b2..1de920728f 100755 --- a/awx/plugins/inventory/tower.py +++ b/awx/plugins/inventory/tower.py @@ -56,6 +56,7 @@ def parse_configuration(): password = os.environ.get("TOWER_PASSWORD", None) ignore_ssl = os.environ.get("TOWER_IGNORE_SSL", "1").lower() in ("1", "yes", "true") inventory = os.environ.get("TOWER_INVENTORY", None) + license_type = os.environ.get("TOWER_LICENSE_TYPE", "enterprise") errors = [] if not host_name: @@ -74,14 +75,30 @@ def parse_configuration(): tower_user=username, tower_pass=password, tower_inventory=inventory, + tower_license_type=license_type, ignore_ssl=ignore_ssl) -def read_tower_inventory(tower_host, tower_user, tower_pass, inventory, ignore_ssl=False): +def read_tower_inventory(tower_host, tower_user, tower_pass, inventory, license_type, ignore_ssl=False): if not re.match('(?:http|https)://', tower_host): tower_host = "https://{}".format(tower_host) - inventory_url = urljoin(tower_host, "/api/v2/inventories/{}/script/?hostvars=1&towervars=1".format(inventory)) + inventory_url = urljoin(tower_host, "/api/v2/inventories/{}/script/?hostvars=1&towervars=1&all=1".format(inventory.replace('/', ''))) + config_url = urljoin(tower_host, "/api/v2/config/") try: + if license_type != "open": + config_response = requests.get(config_url, + auth=HTTPBasicAuth(tower_user, tower_pass), + verify=not ignore_ssl) + if config_response.ok: + source_type = config_response.json()['license_info']['license_type'] + if not source_type == license_type: + print("Tower server licenses must match: source: {} local: {}".format(source_type, + license_type)) + sys.exit(1) + else: + print("Failed to validate the license of the remote Tower: {}".format(config_response.data)) + sys.exit(1) + response = requests.get(inventory_url, auth=HTTPBasicAuth(tower_user, tower_pass), verify=not ignore_ssl) @@ -101,6 +118,7 @@ def main(): config['tower_user'], config['tower_pass'], config['tower_inventory'], + config['tower_license_type'], ignore_ssl=config['ignore_ssl']) print( json.dumps(