Adding license checks for Tower inventory source

* For Tower the license must match between the source and destination
* For AWX the check is disabled
* Hosts imported from another Tower don't count against your license
  in the local Tower
* Fix up some issues with enablement
* Prevent slashes from being used in the instance filter
* Add &all=1 filter to make sure we pick up all hosts
This commit is contained in:
Matthew Jones 2017-10-26 11:32:16 -04:00
parent d282966aa1
commit 5f3ebc26e0
No known key found for this signature in database
GPG Key ID: 76A4C17A97590C1C
5 changed files with 26 additions and 6 deletions

View File

@ -2443,7 +2443,7 @@ class InventoryScriptView(RetrieveAPIView):
for host in obj.hosts.filter(**hosts_q):
data['_meta']['hostvars'][host.name] = host.variables_dict
if towervars:
tower_dict = dict(remote_tower_enabled=host.enabled,
tower_dict = dict(remote_tower_enabled=str(host.enabled).lower(),
remote_tower_id=host.id)
data['_meta']['hostvars'][host.name].update(tower_dict)

View File

@ -7,7 +7,7 @@ import logging
from django.db import models
from django.utils.timezone import now
from django.db.models import Sum
from django.db.models import Sum, Q
from django.conf import settings
from awx.main.utils.filters import SmartFilter
@ -21,9 +21,9 @@ class HostManager(models.Manager):
"""Custom manager class for Hosts model."""
def active_count(self):
"""Return count of active, unique hosts for licensing."""
"""Return count of active, unique hosts for licensing. Exclude ones source from another Tower"""
try:
return self.order_by('name').distinct('name').count()
return self.filter(~Q(inventory_sources__source='tower')).order_by('name').distinct('name').count()
except NotImplementedError: # For unit tests only, SQLite doesn't support distinct('name')
return len(set(self.values_list('name', flat=True)))

View File

@ -1914,6 +1914,7 @@ class RunInventoryUpdate(BaseTask):
env[str(env_k)] = unicode(inventory_update.source_vars_dict[env_k])
elif inventory_update.source == 'tower':
env['TOWER_INVENTORY'] = inventory_update.instance_filters
env['TOWER_LICENSE_TYPE'] = get_licenser().validate()['license_type']
elif inventory_update.source == 'file':
raise NotImplementedError('Cannot update file sources through the task system.')
# add private_data_files

View File

@ -29,6 +29,7 @@ def test_default_cred_types():
'satellite6',
'scm',
'ssh',
'tower',
'vault',
'vmware',
]

View File

@ -56,6 +56,7 @@ def parse_configuration():
password = os.environ.get("TOWER_PASSWORD", None)
ignore_ssl = os.environ.get("TOWER_IGNORE_SSL", "1").lower() in ("1", "yes", "true")
inventory = os.environ.get("TOWER_INVENTORY", None)
license_type = os.environ.get("TOWER_LICENSE_TYPE", "enterprise")
errors = []
if not host_name:
@ -74,14 +75,30 @@ def parse_configuration():
tower_user=username,
tower_pass=password,
tower_inventory=inventory,
tower_license_type=license_type,
ignore_ssl=ignore_ssl)
def read_tower_inventory(tower_host, tower_user, tower_pass, inventory, ignore_ssl=False):
def read_tower_inventory(tower_host, tower_user, tower_pass, inventory, license_type, ignore_ssl=False):
if not re.match('(?:http|https)://', tower_host):
tower_host = "https://{}".format(tower_host)
inventory_url = urljoin(tower_host, "/api/v2/inventories/{}/script/?hostvars=1&towervars=1".format(inventory))
inventory_url = urljoin(tower_host, "/api/v2/inventories/{}/script/?hostvars=1&towervars=1&all=1".format(inventory.replace('/', '')))
config_url = urljoin(tower_host, "/api/v2/config/")
try:
if license_type != "open":
config_response = requests.get(config_url,
auth=HTTPBasicAuth(tower_user, tower_pass),
verify=not ignore_ssl)
if config_response.ok:
source_type = config_response.json()['license_info']['license_type']
if not source_type == license_type:
print("Tower server licenses must match: source: {} local: {}".format(source_type,
license_type))
sys.exit(1)
else:
print("Failed to validate the license of the remote Tower: {}".format(config_response.data))
sys.exit(1)
response = requests.get(inventory_url,
auth=HTTPBasicAuth(tower_user, tower_pass),
verify=not ignore_ssl)
@ -101,6 +118,7 @@ def main():
config['tower_user'],
config['tower_pass'],
config['tower_inventory'],
config['tower_license_type'],
ignore_ssl=config['ignore_ssl'])
print(
json.dumps(