reimplement licensing to work with RHSM and entitlement uploads

Co-authored-by: Christian Adams <chadams@redhat.com>
This commit is contained in:
Ryan Petrello
2020-10-22 10:56:26 -04:00
parent a26d20cbf2
commit ffab48c77f
42 changed files with 1117 additions and 174 deletions

View File

@@ -307,7 +307,7 @@ class BaseAccess(object):
return True # User has access to both, permission check passed
def check_license(self, add_host_name=None, feature=None, check_expiration=True, quiet=False):
validation_info = get_licenser().validate()
validation_info = get_licenser().validate(new_cert=False)
if validation_info.get('license_type', 'UNLICENSED') == 'open':
return
@@ -345,7 +345,7 @@ class BaseAccess(object):
report_violation(_("Host count exceeds available instances."))
def check_org_host_limit(self, data, add_host_name=None):
validation_info = get_licenser().validate()
validation_info = get_licenser().validate(new_cert=False)
if validation_info.get('license_type', 'UNLICENSED') == 'open':
return

View File

@@ -33,9 +33,9 @@ data _since_ the last report date - i.e., new data in the last 24 hours)
'''
@register('config', '1.1', description=_('General platform configuration.'))
@register('config', '1.2', description=_('General platform configuration.'))
def config(since, **kwargs):
license_info = get_license(show_key=False)
license_info = get_license()
install_type = 'traditional'
if os.environ.get('container') == 'oci':
install_type = 'openshift'

View File

@@ -24,7 +24,7 @@ logger = logging.getLogger('awx.main.analytics')
def _valid_license():
try:
if get_license(show_key=False).get('license_type', 'UNLICENSED') == 'open':
if get_license().get('license_type', 'UNLICENSED') == 'open':
return False
access_registry[Job](None).check_license()
except PermissionDenied:

View File

@@ -54,7 +54,7 @@ LICENSE_INSTANCE_FREE = Gauge('awx_license_instance_free', 'Number of remaining
def metrics():
license_info = get_license(show_key=False)
license_info = get_license()
SYSTEM_INFO.info({
'install_uuid': settings.INSTALL_UUID,
'insights_analytics': str(settings.INSIGHTS_TRACKING_STATE),

View File

@@ -12,6 +12,8 @@ from rest_framework.fields import FloatField
# Tower
from awx.conf import fields, register, register_validate
from awx.main.validators import validate_entitlement_cert
logger = logging.getLogger('awx.main.conf')
@@ -142,6 +144,32 @@ register(
category_slug='system',
)
register(
'SUBSCRIPTIONS_USERNAME',
field_class=fields.CharField,
default='',
allow_blank=True,
encrypted=False,
read_only=False,
label=_('Red Hat or Satellite username'),
help_text=_('This username is used to retrieve subscription and content information'), # noqa
category=_('System'),
category_slug='system',
)
register(
'SUBSCRIPTIONS_PASSWORD',
field_class=fields.CharField,
default='',
allow_blank=True,
encrypted=True,
read_only=False,
label=_('Red Hat or Satellite password'),
help_text=_('This password is used to retrieve subscription and content information'), # noqa
category=_('System'),
category_slug='system',
)
register(
'AUTOMATION_ANALYTICS_URL',
field_class=fields.URLField,
@@ -328,6 +356,21 @@ register(
category_slug='jobs',
)
register(
'ENTITLEMENT_CERT',
field_class=fields.CharField,
allow_blank=True,
default='',
required=False,
validators=[validate_entitlement_cert], # TODO: may need to use/modify `validate_certificate` validator
label=_('RHSM Entitlement Public Certificate and Private Key'),
help_text=_('Obtain a key pair via subscription-manager, or https://access.redhat.com. Refer to Ansible Tower docs for formatting key pair.'),
category=_('SYSTEM'),
category_slug='system',
encrypted=True,
)
register(
'AWX_RESOURCE_PROFILING_ENABLED',
field_class=fields.BooleanField,

View File

@@ -16,9 +16,7 @@ class Command(BaseCommand):
def handle(self, *args, **options):
super(Command, self).__init__()
license = get_licenser().validate()
license = get_licenser().validate(new_cert=False)
if options.get('data'):
if license.get('license_key', '') != 'UNLICENSED':
license['license_key'] = '********'
return json.dumps(license)
return license.get('license_type', 'none')

View File

@@ -901,7 +901,7 @@ class Command(BaseCommand):
))
def check_license(self):
license_info = get_licenser().validate()
license_info = get_licenser().validate(new_cert=False)
local_license_type = license_info.get('license_type', 'UNLICENSED')
if license_info.get('license_key', 'UNLICENSED') == 'UNLICENSED':
logger.error(LICENSE_NON_EXISTANT_MESSAGE)
@@ -938,7 +938,7 @@ class Command(BaseCommand):
logger.warning(LICENSE_MESSAGE % d)
def check_org_host_limit(self):
license_info = get_licenser().validate()
license_info = get_licenser().validate(new_cert=False)
if license_info.get('license_type', 'UNLICENSED') == 'open':
return

View File

@@ -2160,7 +2160,7 @@ class RunProjectUpdate(BaseTask):
'local_path': os.path.basename(project_update.project.local_path),
'project_path': project_update.get_project_path(check_if_exists=False), # deprecated
'insights_url': settings.INSIGHTS_URL_BASE,
'awx_license_type': get_license(show_key=False).get('license_type', 'UNLICENSED'),
'awx_license_type': get_license().get('license_type', 'UNLICENSED'),
'awx_version': get_awx_version(),
'scm_url': scm_url,
'scm_branch': scm_branch,

View File

@@ -1,13 +1,16 @@
# Copyright (c) 2015 Ansible, Inc.
# All Rights Reserved.
from awx.main.utils.common import StubLicense
# from awx.main.utils.common import StubLicense
#
#
# def test_stub_license():
# license_actual = StubLicense().validate()
# assert license_actual['license_key'] == 'OPEN'
# assert license_actual['valid_key']
# assert license_actual['compliant']
# assert license_actual['license_type'] == 'open'
def test_stub_license():
license_actual = StubLicense().validate()
assert license_actual['license_key'] == 'OPEN'
assert license_actual['valid_key']
assert license_actual['compliant']
assert license_actual['license_type'] == 'open'
# Test license_date is always seconds

View File

@@ -30,8 +30,7 @@ def test_python_and_js_licenses():
# Check variations of '-' and '_' in filenames due to python
for fname in [name, name.replace('-','_')]:
if entry.startswith(fname) and entry.endswith('.tar.gz'):
entry = entry[:-7]
(n, v) = entry.rsplit('-',1)
v = entry.split(name + '-')[1].split('.tar.gz')[0]
return v
return None

View File

@@ -39,6 +39,8 @@ from awx.main import tasks
from awx.main.utils import encrypt_field, encrypt_value
from awx.main.utils.safe_yaml import SafeLoader
from awx.main.utils.licensing import Licenser
class TestJobExecution(object):
EXAMPLE_PRIVATE_KEY = '-----BEGIN PRIVATE KEY-----\nxyz==\n-----END PRIVATE KEY-----'
@@ -1830,7 +1832,10 @@ class TestProjectUpdateGalaxyCredentials(TestJobExecution):
task = RunProjectUpdate()
env = task.build_env(project_update, private_data_dir)
task.build_extra_vars_file(project_update, private_data_dir)
with mock.patch.object(Licenser, 'validate', lambda *args, **kw: {}):
task.build_extra_vars_file(project_update, private_data_dir)
assert task.__vars__['roles_enabled'] is False
assert task.__vars__['collections_enabled'] is False
for k in env:
@@ -1850,7 +1855,10 @@ class TestProjectUpdateGalaxyCredentials(TestJobExecution):
project_update.project.organization.galaxy_credentials.add(public_galaxy)
task = RunProjectUpdate()
env = task.build_env(project_update, private_data_dir)
task.build_extra_vars_file(project_update, private_data_dir)
with mock.patch.object(Licenser, 'validate', lambda *args, **kw: {}):
task.build_extra_vars_file(project_update, private_data_dir)
assert task.__vars__['roles_enabled'] is True
assert task.__vars__['collections_enabled'] is True
assert sorted([
@@ -1935,7 +1943,9 @@ class TestProjectUpdateCredentials(TestJobExecution):
assert settings.PROJECTS_ROOT in process_isolation['process_isolation_show_paths']
task._write_extra_vars_file = mock.Mock()
task.build_extra_vars_file(project_update, private_data_dir)
with mock.patch.object(Licenser, 'validate', lambda *args, **kw: {}):
task.build_extra_vars_file(project_update, private_data_dir)
call_args, _ = task._write_extra_vars_file.call_args_list[0]
_, extra_vars = call_args

View File

@@ -55,8 +55,7 @@ __all__ = [
'model_instance_diff', 'parse_yaml_or_json', 'RequireDebugTrueOrTest',
'has_model_field_prefetched', 'set_environ', 'IllegalArgumentError',
'get_custom_venv_choices', 'get_external_account', 'task_manager_bulk_reschedule',
'schedule_task_manager', 'classproperty', 'create_temporary_fifo', 'truncate_stdout',
'StubLicense'
'schedule_task_manager', 'classproperty', 'create_temporary_fifo', 'truncate_stdout'
]
@@ -190,7 +189,7 @@ def get_awx_version():
def get_awx_http_client_headers():
license = get_license(show_key=False).get('license_type', 'UNLICENSED')
license = get_license().get('license_type', 'UNLICENSED')
headers = {
'Content-Type': 'application/json',
'User-Agent': '{} {} ({})'.format(
@@ -202,34 +201,14 @@ def get_awx_http_client_headers():
return headers
class StubLicense(object):
features = {
'activity_streams': True,
'ha': True,
'ldap': True,
'multiple_organizations': True,
'surveys': True,
'system_tracking': True,
'rebranding': True,
'enterprise_auth': True,
'workflows': True,
}
def validate(self):
return dict(license_key='OPEN',
valid_key=True,
compliant=True,
features=self.features,
license_type='open')
# Update all references of this function in the codebase to just import `from awx.main.utils.licensing import Licenser` directly
def get_licenser(*args, **kwargs):
try:
from tower_license import TowerLicense
return TowerLicense(*args, **kwargs)
except ImportError:
return StubLicense(*args, **kwargs)
# Get License Config from db?
from awx.main.utils.licensing import Licenser
return Licenser(*args, **kwargs)
except Exception as e:
raise ValueError(_('Error importing Tower License: %s') % e)
def update_scm_url(scm_type, url, username=True, password=True,

382
awx/main/utils/licensing.py Normal file
View File

@@ -0,0 +1,382 @@
# Copyright (c) 2015 Ansible, Inc.
# All Rights Reserved.
'''
This is intended to be a lightweight license class for verifying subscriptions, and parsing subscription data
from entitlement certificates.
The Licenser class can do the following:
- Parse an Entitlement cert to generate license
'''
import os
import configparser
from datetime import datetime
import collections
import copy
import tempfile
import logging
import re
import requests
import time
# Django
from django.conf import settings
# AWX
from awx.main.models import Host
# RHSM
from rhsm import certificate
MAX_INSTANCES = 9999999
logger = logging.getLogger(__name__)
def rhsm_config():
config = configparser.ConfigParser()
config.read('/etc/rhsm/rhsm.conf')
return config
class Licenser(object):
# warn when there is a month (30 days) left on the license
LICENSE_TIMEOUT = 60 * 60 * 24 * 30
UNLICENSED_DATA = dict(
subscription_name=None,
sku=None,
support_level=None,
instance_count=0,
license_date=0,
license_type="UNLICENSED",
product_name="Red Hat Ansible Tower",
valid_key=False
)
def __init__(self, **kwargs):
self._attrs = dict(
instance_count=0,
license_date=0,
license_type='UNLICENSED',
)
self.config = rhsm_config()
if not kwargs:
license_setting = getattr(settings, 'LICENSE', None)
if license_setting is not None:
kwargs = license_setting
if 'company_name' in kwargs:
kwargs.pop('company_name')
self._attrs.update(kwargs)
if self._check_product_cert():
if 'license_key' in self._attrs:
self._unset_attrs()
if 'valid_key' in self._attrs:
if not self._attrs['valid_key']:
self._unset_attrs()
else:
self._unset_attrs()
else:
self._generate_open_config()
def _check_product_cert(self):
# Product Cert Name: ansible-tower-3.7-rhel-7.x86_64.pem
# Maybe check validity of Product Cert somehow?
if os.path.exists('/etc/tower/certs') and os.path.exists('/var/lib/awx/.tower_version'):
return True
return False
def _generate_open_config(self):
self._attrs.update(dict(license_type='open',
valid_key=True,
subscription_name='OPEN',
product_name="AWX",
))
settings.LICENSE = self._attrs
def _unset_attrs(self):
self._attrs = self.UNLICENSED_DATA.copy()
def _clear_license_setting(self):
self.unset_attrs()
settings.LICENSE = {}
def _generate_product_config(self):
raw_cert = getattr(settings, 'ENTITLEMENT_CERT', None)
# Fail early if no entitlement cert is available
if not raw_cert or raw_cert == '':
self._clear_license_setting()
return
# Read certificate
certinfo = certificate.create_from_pem(raw_cert)
if not certinfo.is_valid():
raise ValueError("Could not parse entitlement certificate")
if certinfo.is_expired():
raise ValueError("Certificate is expired")
if not any(map(lambda x: x.id == '480', certinfo.products)):
self._clear_license_setting()
raise ValueError("Certificate is for another product")
# Parse output for subscription metadata to build config
license = dict()
license['sku'] = certinfo.order.sku
license['instance_count'] = int(certinfo.order.quantity_used)
license['support_level'] = certinfo.order.service_level
license['subscription_name'] = certinfo.order.name
license['pool_id'] = certinfo.pool.id
license['license_date'] = certinfo.end.strftime('%s')
license['product_name'] = certinfo.order.name
license['valid_key'] = True
license['license_type'] = 'enterprise'
license['satellite'] = False
self._attrs.update(license)
settings.LICENSE = self._attrs
return self._attrs
def update(self, **kwargs):
# Update attributes of the current license.
if 'instance_count' in kwargs:
kwargs['instance_count'] = int(kwargs['instance_count'])
if 'license_date' in kwargs:
kwargs['license_date'] = int(kwargs['license_date'])
self._attrs.update(kwargs)
def validate_rh(self, user, pw):
host = 'https://' + str(self.config.get("server", "hostname"))
if not host:
host = getattr(settings, 'REDHAT_CANDLEPIN_HOST', None)
if not user:
raise ValueError('subscriptions_username is required')
if not pw:
raise ValueError('subscriptions_password is required')
if host and user and pw:
if 'subscription.rhsm.redhat.com' in host:
json = self.get_rhsm_subs(host, user, pw)
else:
json = self.get_satellite_subs(host, user, pw)
return self.generate_license_options_from_entitlements(json)
return []
def get_rhsm_subs(self, host, user, pw):
verify = getattr(settings, 'REDHAT_CANDLEPIN_VERIFY', False)
json = []
try:
subs = requests.get(
'/'.join([host, 'subscription/users/{}/owners'.format(user)]),
verify=verify,
auth=(user, pw)
)
except requests.exceptions.ConnectionError as error:
raise error
except OSError as error:
raise OSError('Unable to open certificate bundle {}. Check that Ansible Tower is running on Red Hat Enterprise Linux.'.format(verify)) from error # noqa
subs.raise_for_status()
for sub in subs.json():
resp = requests.get(
'/'.join([
host,
'subscription/owners/{}/pools/?match=*tower*'.format(sub['key'])
]),
verify=verify,
auth=(user, pw)
)
resp.raise_for_status()
json.extend(resp.json())
return json
def get_satellite_subs(self, host, user, pw):
try:
verify = str(self.config.get("rhsm", "repo_ca_cert"))
except Exception as e:
raise OSError('Unable to read rhsm config to get ca_cert location. {}'.format(str(e)))
json = []
try:
orgs = requests.get(
'/'.join([host, 'katello/api/organizations']),
verify=verify,
auth=(user, pw)
)
except requests.exceptions.ConnectionError as error:
raise error
except OSError as error:
raise OSError('Unable to open certificate bundle {}. Check that Ansible Tower is running on Red Hat Enterprise Linux.'.format(verify)) from error # noqa
orgs.raise_for_status()
for org in orgs.json()['results']:
resp = requests.get(
'/'.join([
host,
'/katello/api/organizations/{}/subscriptions/?search=Red Hat Ansible Automation'.format(org['id'])
]),
verify=verify,
auth=(user, pw)
)
resp.raise_for_status()
results = resp.json()['results']
if results != []:
for sub in results:
# Parse output for subscription metadata to build config
license = dict()
license['productId'] = sub['product_id']
license['quantity'] = int(sub['quantity'])
license['support_level'] = sub['support_level']
license['subscription_name'] = sub['name']
license['id'] = sub['upstream_pool_id']
license['endDate'] = sub['end_date']
license['productName'] = "Red Hat Ansible Automation"
license['valid_key'] = True
license['license_type'] = 'enterprise'
license['satellite'] = True
json.append(license)
return json
def is_appropriate_sat_sub(self, sub):
if 'Red Hat Ansible Automation' not in sub['subscription_name']:
return False
return True
def is_appropriate_sub(self, sub):
if sub['activeSubscription'] is False:
return False
# Products that contain Ansible Tower
products = sub.get('providedProducts', [])
if any(map(lambda product: product.get('productId', None) == "480", products)):
return True
return False
def generate_license_options_from_entitlements(self, json):
from dateutil.parser import parse
ValidSub = collections.namedtuple('ValidSub', 'sku name support_level end_date trial quantity pool_id satellite')
valid_subs = []
for sub in json:
satellite = sub['satellite']
if satellite:
is_valid = self.is_appropriate_sat_sub(sub)
else:
is_valid = self.is_appropriate_sub(sub)
if is_valid:
try:
end_date = parse(sub.get('endDate'))
except Exception:
continue
now = datetime.utcnow()
now = now.replace(tzinfo=end_date.tzinfo)
if end_date < now:
# If the sub has a past end date, skip it
continue
try:
quantity = int(sub['quantity'])
if quantity == -1:
# effectively, unlimited
quantity = MAX_INSTANCES
except Exception:
continue
sku = sub['productId']
trial = sku.startswith('S') # i.e.,, SER/SVC
support_level = ''
pool_id = sub['id']
if satellite:
support_level = sub['support_level']
else:
for attr in sub.get('productAttributes', []):
if attr.get('name') == 'support_level':
support_level = attr.get('value')
valid_subs.append(ValidSub(
sku, sub['productName'], support_level, end_date, trial, quantity, pool_id, satellite
))
if valid_subs:
licenses = []
for sub in valid_subs:
license = self.__class__(subscription_name='Ansible Tower by Red Hat')
license._attrs['instance_count'] = int(sub.quantity)
license._attrs['sku'] = sub.sku
license._attrs['support_level'] = sub.support_level
license._attrs['license_type'] = 'enterprise'
if sub.trial:
license._attrs['trial'] = True
license._attrs['instance_count'] = min(
MAX_INSTANCES, license._attrs['instance_count']
)
human_instances = license._attrs['instance_count']
if human_instances == MAX_INSTANCES:
human_instances = 'Unlimited'
subscription_name = re.sub(
r' \([\d]+ Managed Nodes',
' ({} Managed Nodes'.format(human_instances),
sub.name
)
license._attrs['subscription_name'] = subscription_name
license._attrs['satellite'] = satellite
license.update(
license_date=int(sub.end_date.strftime('%s'))
)
license.update(
pool_id=sub.pool_id
)
licenses.append(license._attrs.copy())
return licenses
raise ValueError(
'No valid Red Hat Ansible Automation subscription could be found for this account.' # noqa
)
def validate(self, new_cert=False):
# Generate Config from Entitlement cert if it exists
if new_cert:
self._generate_product_config()
# Return license attributes with additional validation info.
attrs = copy.deepcopy(self._attrs)
type = attrs.get('license_type', 'none')
if (type == 'UNLICENSED' or False):
attrs.update(dict(valid_key=False, compliant=False))
return attrs
attrs['valid_key'] = True
if Host:
current_instances = Host.objects.active_count()
else:
current_instances = 0
available_instances = int(attrs.get('instance_count', None) or 0)
attrs['current_instances'] = current_instances
attrs['available_instances'] = available_instances
attrs['free_instances'] = max(0, available_instances - current_instances)
license_date = int(attrs.get('license_date', None) or 0)
current_date = int(time.time())
time_remaining = license_date - current_date
attrs['time_remaining'] = time_remaining
if attrs.setdefault('trial', False):
attrs['grace_period_remaining'] = time_remaining
else:
attrs['grace_period_remaining'] = (license_date + 2592000) - current_date
attrs['compliant'] = bool(time_remaining > 0 and attrs['free_instances'] >= 0)
attrs['date_warning'] = bool(time_remaining < self.LICENSE_TIMEOUT)
attrs['date_expired'] = bool(time_remaining <= 0)
return attrs

View File

@@ -17,6 +17,12 @@ from rest_framework.exceptions import ParseError
from awx.main.utils import parse_yaml_or_json
def validate_entitlement_cert(data, min_keys=0, max_keys=None, min_certs=0, max_certs=None):
# TODO: replace with more actual robust logic here to allow for multiple certificates in one cert file (because this is how entitlements do)
# This is a temporary hack that should not be merged. Look at Ln:92
pass
def validate_pem(data, min_keys=0, max_keys=None, min_certs=0, max_certs=None):
"""
Validate the given PEM data is valid and contains the required numbers of