mirror of
https://github.com/ansible/awx.git
synced 2026-04-05 01:59:25 -02:30
update trial license enforcement logic
This commit is contained in:
@@ -17,6 +17,8 @@ from rest_framework.permissions import AllowAny, IsAuthenticated
|
|||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
from awx.api.generics import APIView
|
from awx.api.generics import APIView
|
||||||
from awx.main.ha import is_ha_environment
|
from awx.main.ha import is_ha_environment
|
||||||
from awx.main.utils import (
|
from awx.main.utils import (
|
||||||
@@ -248,14 +250,39 @@ class ApiV2ConfigView(APIView):
|
|||||||
logger.info(smart_text(u"Invalid JSON submitted for license."),
|
logger.info(smart_text(u"Invalid JSON submitted for license."),
|
||||||
extra=dict(actor=request.user.username))
|
extra=dict(actor=request.user.username))
|
||||||
return Response({"error": _("Invalid JSON")}, status=status.HTTP_400_BAD_REQUEST)
|
return Response({"error": _("Invalid JSON")}, status=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from awx.main.utils.common import get_licenser
|
from awx.main.utils.common import get_licenser
|
||||||
license_data = json.loads(data_actual)
|
license_data = json.loads(data_actual)
|
||||||
|
if license_data.get('rh_password') == '$encrypted$':
|
||||||
|
license_data['rh_password'] = settings.REDHAT_PASSWORD
|
||||||
license_data_validated = get_licenser(**license_data).validate()
|
license_data_validated = get_licenser(**license_data).validate()
|
||||||
except Exception:
|
if license_data_validated.get('valid_key') and 'license_key' not in license_data:
|
||||||
logger.warning(smart_text(u"Invalid license submitted."),
|
if license_data.get('rh_username') and license_data.get('rh_password'):
|
||||||
extra=dict(actor=request.user.username))
|
settings.REDHAT_USERNAME = license_data['rh_username']
|
||||||
return Response({"error": _("Invalid License")}, status=status.HTTP_400_BAD_REQUEST)
|
settings.REDHAT_PASSWORD = license_data['rh_password']
|
||||||
|
license_data = {
|
||||||
|
"eula_accepted": eula_accepted,
|
||||||
|
"features": license_data_validated['features'],
|
||||||
|
"license_type": license_data_validated['license_type'],
|
||||||
|
"license_date": license_data_validated['license_date'],
|
||||||
|
"license_key": license_data_validated['license_key'],
|
||||||
|
"instance_count": license_data_validated['instance_count'],
|
||||||
|
}
|
||||||
|
if license_data_validated.get('trial'):
|
||||||
|
license_data['trial'] = True
|
||||||
|
except Exception as exc:
|
||||||
|
msg = _("Invalid License")
|
||||||
|
if (
|
||||||
|
isinstance(exc, requests.exceptions.HTTPError) and
|
||||||
|
getattr(getattr(exc, 'response', None), 'status_code', None) == 401
|
||||||
|
):
|
||||||
|
msg = _("The provided credentials are invalid (HTTP 401).")
|
||||||
|
if isinstance(exc, ValueError) and exc.args:
|
||||||
|
msg = exc.args[0]
|
||||||
|
logger.exception(smart_text(u"Invalid license submitted."),
|
||||||
|
extra=dict(actor=request.user.username))
|
||||||
|
return Response({"error": msg}, status=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
# If the license is valid, write it to the database.
|
# If the license is valid, write it to the database.
|
||||||
if license_data_validated['valid_key']:
|
if license_data_validated['valid_key']:
|
||||||
|
|||||||
@@ -317,10 +317,19 @@ class BaseAccess(object):
|
|||||||
validation_info['time_remaining'] = 99999999
|
validation_info['time_remaining'] = 99999999
|
||||||
validation_info['grace_period_remaining'] = 99999999
|
validation_info['grace_period_remaining'] = 99999999
|
||||||
|
|
||||||
|
report_violation = lambda message: logger.error(message)
|
||||||
|
|
||||||
|
if (
|
||||||
|
validation_info.get('trial', False) is True or
|
||||||
|
validation_info['instance_count'] == 10 # basic 10 license
|
||||||
|
):
|
||||||
|
def report_violation(message):
|
||||||
|
raise PermissionDenied(message)
|
||||||
|
|
||||||
if check_expiration and validation_info.get('time_remaining', None) is None:
|
if check_expiration and validation_info.get('time_remaining', None) is None:
|
||||||
raise PermissionDenied(_("License is missing."))
|
raise PermissionDenied(_("License is missing."))
|
||||||
if check_expiration and validation_info.get("grace_period_remaining") <= 0:
|
elif check_expiration and validation_info.get("grace_period_remaining") <= 0:
|
||||||
raise PermissionDenied(_("License has expired."))
|
report_violation(_("License has expired."))
|
||||||
|
|
||||||
free_instances = validation_info.get('free_instances', 0)
|
free_instances = validation_info.get('free_instances', 0)
|
||||||
available_instances = validation_info.get('available_instances', 0)
|
available_instances = validation_info.get('available_instances', 0)
|
||||||
@@ -328,11 +337,11 @@ class BaseAccess(object):
|
|||||||
if add_host_name:
|
if add_host_name:
|
||||||
host_exists = Host.objects.filter(name=add_host_name).exists()
|
host_exists = Host.objects.filter(name=add_host_name).exists()
|
||||||
if not host_exists and free_instances == 0:
|
if not host_exists and free_instances == 0:
|
||||||
raise PermissionDenied(_("License count of %s instances has been reached.") % available_instances)
|
report_violation(_("License count of %s instances has been reached.") % available_instances)
|
||||||
elif not host_exists and free_instances < 0:
|
elif not host_exists and free_instances < 0:
|
||||||
raise PermissionDenied(_("License count of %s instances has been exceeded.") % available_instances)
|
report_violation(_("License count of %s instances has been exceeded.") % available_instances)
|
||||||
elif not add_host_name and free_instances < 0:
|
elif not add_host_name and free_instances < 0:
|
||||||
raise PermissionDenied(_("Host count exceeds available instances."))
|
report_violation(_("Host count exceeds available instances."))
|
||||||
|
|
||||||
def check_org_host_limit(self, data, add_host_name=None):
|
def check_org_host_limit(self, data, add_host_name=None):
|
||||||
validation_info = get_licenser().validate()
|
validation_info = get_licenser().validate()
|
||||||
|
|||||||
@@ -919,7 +919,8 @@ class Command(BaseCommand):
|
|||||||
new_count = Host.objects.active_count()
|
new_count = Host.objects.active_count()
|
||||||
if time_remaining <= 0 and not license_info.get('demo', False):
|
if time_remaining <= 0 and not license_info.get('demo', False):
|
||||||
logger.error(LICENSE_EXPIRED_MESSAGE)
|
logger.error(LICENSE_EXPIRED_MESSAGE)
|
||||||
raise CommandError("License has expired!")
|
if license_info.get('trial', False) is True:
|
||||||
|
raise CommandError("License has expired!")
|
||||||
# special check for tower-type inventory sources
|
# special check for tower-type inventory sources
|
||||||
# but only if running the plugin
|
# but only if running the plugin
|
||||||
TOWER_SOURCE_FILES = ['tower.yml', 'tower.yaml']
|
TOWER_SOURCE_FILES = ['tower.yml', 'tower.yaml']
|
||||||
@@ -936,7 +937,11 @@ class Command(BaseCommand):
|
|||||||
logger.error(DEMO_LICENSE_MESSAGE % d)
|
logger.error(DEMO_LICENSE_MESSAGE % d)
|
||||||
else:
|
else:
|
||||||
logger.error(LICENSE_MESSAGE % d)
|
logger.error(LICENSE_MESSAGE % d)
|
||||||
raise CommandError('License count exceeded!')
|
if (
|
||||||
|
license_info.get('trial', False) is True or
|
||||||
|
license_info['instance_count'] == 10 # basic 10 license
|
||||||
|
):
|
||||||
|
raise CommandError('License count exceeded!')
|
||||||
|
|
||||||
def check_org_host_limit(self):
|
def check_org_host_limit(self):
|
||||||
license_info = get_licenser().validate()
|
license_info = get_licenser().validate()
|
||||||
|
|||||||
Reference in New Issue
Block a user