diff --git a/awx/api/filters.py b/awx/api/filters.py index 050534a963..af0060c008 100644 --- a/awx/api/filters.py +++ b/awx/api/filters.py @@ -19,6 +19,11 @@ from rest_framework.filters import BaseFilterBackend # Ansible Tower from awx.main.utils import get_type_for_model, to_python_boolean +class MongoFilterBackend(BaseFilterBackend): + + def filter_queryset(self, request, queryset, view): + return queryset + class ActiveOnlyBackend(BaseFilterBackend): ''' Filter to show only objects where is_active/active is True. @@ -61,7 +66,7 @@ class TypeFilterBackend(BaseFilterBackend): queryset = queryset.filter(polymorphic_ctype_id__in=types_pks) elif model_type in types: queryset = queryset - else: + else: queryset = queryset.none() return queryset except FieldError, e: diff --git a/awx/api/generics.py b/awx/api/generics.py index 1ac01a3eb6..b5e90ee876 100644 --- a/awx/api/generics.py +++ b/awx/api/generics.py @@ -31,7 +31,8 @@ __all__ = ['APIView', 'GenericAPIView', 'ListAPIView', 'SimpleListAPIView', 'ListCreateAPIView', 'SubListAPIView', 'SubListCreateAPIView', 'SubListCreateAttachDetachAPIView', 'RetrieveAPIView', 'RetrieveUpdateAPIView', 'RetrieveDestroyAPIView', - 'RetrieveUpdateDestroyAPIView', 'DestroyAPIView'] + 'RetrieveUpdateDestroyAPIView', 'DestroyAPIView', + 'MongoAPIView', 'MongoListAPIView'] logger = logging.getLogger('awx.api.generics') @@ -164,7 +165,6 @@ class APIView(views.APIView): ret['added_in_version'] = added_in_version return ret - class GenericAPIView(generics.GenericAPIView, APIView): # Base class for all model-based views. @@ -195,11 +195,13 @@ class GenericAPIView(generics.GenericAPIView, APIView): if not hasattr(self, 'format_kwarg'): self.format_kwarg = 'format' d = super(GenericAPIView, self).get_description_context() - d.update({ - 'model_verbose_name': unicode(self.model._meta.verbose_name), - 'model_verbose_name_plural': unicode(self.model._meta.verbose_name_plural), - 'serializer_fields': self.get_serializer().metadata(), - }) + if hasattr(self.model, "_meta"): + if hasattr(self.model._meta, "verbose_name"): + d.update({ + 'model_verbose_name': unicode(self.model._meta.verbose_name), + 'model_verbose_name_plural': unicode(self.model._meta.verbose_name_plural), + }) + d.update({'serializer_fields': self.get_serializer().metadata()}) return d def metadata(self, request): @@ -252,6 +254,27 @@ class GenericAPIView(generics.GenericAPIView, APIView): ret['search_fields'] = self.search_fields return ret +class MongoAPIView(GenericAPIView): + + def get_parent_object(self): + parent_filter = { + self.lookup_field: self.kwargs.get(self.lookup_field, None), + } + return get_object_or_404(self.parent_model, **parent_filter) + + def check_parent_access(self, parent=None): + parent = parent or self.get_parent_object() + parent_access = getattr(self, 'parent_access', 'read') + if parent_access in ('read', 'delete'): + args = (self.parent_model, parent_access, parent) + else: + args = (self.parent_model, parent_access, parent, None) + if not self.request.user.can_access(*args): + raise PermissionDenied() + +class MongoListAPIView(generics.ListAPIView, MongoAPIView): + pass + class SimpleListAPIView(generics.ListAPIView, GenericAPIView): def get_queryset(self): diff --git a/awx/api/serializers.py b/awx/api/serializers.py index 65b573de11..f44361a79e 100644 --- a/awx/api/serializers.py +++ b/awx/api/serializers.py @@ -8,6 +8,8 @@ import logging from dateutil import rrule from ast import literal_eval +from rest_framework_mongoengine.serializers import MongoEngineModelSerializer + # PyYAML import yaml @@ -33,9 +35,11 @@ from polymorphic import PolymorphicModel # AWX from awx.main.constants import SCHEDULEABLE_PROVIDERS from awx.main.models import * # noqa -from awx.main.utils import get_type_for_model, get_model_for_type +from awx.main.utils import get_type_for_model, get_model_for_type, build_url, timestamp_apiformat from awx.main.redact import REPLACE_STR +from awx.fact.models import * # noqa + logger = logging.getLogger('awx.api.serializers') # Fields that should be summarized regardless of object type. @@ -774,6 +778,7 @@ class InventorySerializer(BaseSerializerWithVariables): activity_stream = reverse('api:inventory_activity_stream_list', args=(obj.pk,)), scan_job_templates = reverse('api:inventory_scan_job_template_list', args=(obj.pk,)), ad_hoc_commands = reverse('api:inventory_ad_hoc_commands_list', args=(obj.pk,)), + single_fact = reverse('api:inventory_single_fact_view', args=(obj.pk,)), )) if obj.organization and obj.organization.active: res['organization'] = reverse('api:organization_detail', args=(obj.organization.pk,)) @@ -826,6 +831,8 @@ class HostSerializer(BaseSerializerWithVariables): inventory_sources = reverse('api:host_inventory_sources_list', args=(obj.pk,)), ad_hoc_commands = reverse('api:host_ad_hoc_commands_list', args=(obj.pk,)), ad_hoc_command_events = reverse('api:host_ad_hoc_command_events_list', args=(obj.pk,)), + fact_versions = reverse('api:host_fact_versions_list', args=(obj.pk,)), + single_fact = reverse('api:host_single_fact_view', args=(obj.pk,)), )) if obj.inventory and obj.inventory.active: res['inventory'] = reverse('api:inventory_detail', args=(obj.inventory.pk,)) @@ -927,6 +934,7 @@ class GroupSerializer(BaseSerializerWithVariables): activity_stream = reverse('api:group_activity_stream_list', args=(obj.pk,)), inventory_sources = reverse('api:group_inventory_sources_list', args=(obj.pk,)), ad_hoc_commands = reverse('api:group_ad_hoc_commands_list', args=(obj.pk,)), + single_fact = reverse('api:group_single_fact_view', args=(obj.pk,)), )) if obj.inventory and obj.inventory.active: res['inventory'] = reverse('api:inventory_detail', args=(obj.inventory.pk,)) @@ -1537,12 +1545,10 @@ class JobRelaunchSerializer(JobSerializer): obj = self.context.get('obj') if not obj.credential or obj.credential.active is False: raise serializers.ValidationError(dict(credential=["Credential not found or deleted."])) - if obj.job_type != PERM_INVENTORY_SCAN and (obj.project is None or not obj.project.active): raise serializers.ValidationError(dict(errors=["Job Template Project is missing or undefined"])) if obj.inventory is None or not obj.inventory.active: raise serializers.ValidationError(dict(errors=["Job Template Inventory is missing or undefined"])) - return attrs class AdHocCommandSerializer(UnifiedJobSerializer): @@ -2010,3 +2016,30 @@ class AuthTokenSerializer(serializers.Serializer): raise serializers.ValidationError('Unable to login with provided credentials.') else: raise serializers.ValidationError('Must include "username" and "password"') + + +class FactVersionSerializer(MongoEngineModelSerializer): + related = serializers.SerializerMethodField('get_related') + + class Meta: + model = FactVersion + fields = ('related', 'module', 'timestamp',) + + def get_related(self, obj): + host_obj = self.context.get('host_obj') + res = {} + params = { + 'datetime': timestamp_apiformat(obj.timestamp), + 'module': obj.module, + } + res.update(dict( + fact_view = build_url('api:host_fact_compare_view', args=(host_obj.pk,), get=params), + )) + return res + +class FactSerializer(MongoEngineModelSerializer): + + class Meta: + model = Fact + depth = 2 + fields = ('timestamp', 'host', 'module', 'fact') diff --git a/awx/api/urls.py b/awx/api/urls.py index 16d5ddb2ea..0351be1ac8 100644 --- a/awx/api/urls.py +++ b/awx/api/urls.py @@ -75,6 +75,7 @@ inventory_urls = patterns('awx.api.views', url(r'^(?P[0-9]+)/activity_stream/$', 'inventory_activity_stream_list'), url(r'^(?P[0-9]+)/scan_job_templates/$', 'inventory_scan_job_template_list'), url(r'^(?P[0-9]+)/ad_hoc_commands/$', 'inventory_ad_hoc_commands_list'), + url(r'^(?P[0-9]+)/single_fact/$', 'inventory_single_fact_view'), ) host_urls = patterns('awx.api.views', @@ -89,6 +90,9 @@ host_urls = patterns('awx.api.views', url(r'^(?P[0-9]+)/inventory_sources/$', 'host_inventory_sources_list'), url(r'^(?P[0-9]+)/ad_hoc_commands/$', 'host_ad_hoc_commands_list'), url(r'^(?P[0-9]+)/ad_hoc_command_events/$', 'host_ad_hoc_command_events_list'), + url(r'^(?P[0-9]+)/single_fact/$', 'host_single_fact_view'), + url(r'^(?P[0-9]+)/fact_versions/$', 'host_fact_versions_list'), + url(r'^(?P[0-9]+)/fact_view/$', 'host_fact_compare_view'), ) group_urls = patterns('awx.api.views', @@ -104,6 +108,7 @@ group_urls = patterns('awx.api.views', url(r'^(?P[0-9]+)/activity_stream/$', 'group_activity_stream_list'), url(r'^(?P[0-9]+)/inventory_sources/$', 'group_inventory_sources_list'), url(r'^(?P[0-9]+)/ad_hoc_commands/$', 'group_ad_hoc_commands_list'), + url(r'^(?P[0-9]+)/single_fact/$', 'group_single_fact_view'), ) inventory_source_urls = patterns('awx.api.views', diff --git a/awx/api/views.py b/awx/api/views.py index e52dd77009..20a6681d36 100644 --- a/awx/api/views.py +++ b/awx/api/views.py @@ -46,6 +46,7 @@ from awx.main.access import get_user_queryset from awx.main.ha import is_ha_environment from awx.api.authentication import TaskAuthentication from awx.api.utils.decorators import paginated +from awx.api.filters import MongoFilterBackend from awx.api.generics import get_view_name from awx.api.generics import * # noqa from awx.main.models import * # noqa @@ -53,6 +54,7 @@ from awx.main.utils import * # noqa from awx.api.permissions import * # noqa from awx.api.renderers import * # noqa from awx.api.serializers import * # noqa +from awx.fact.models import * # noqa def api_exception_handler(exc): ''' @@ -922,6 +924,27 @@ class InventoryScanJobTemplateList(SubListAPIView): qs = self.request.user.get_queryset(self.model) return qs.filter(job_type=PERM_INVENTORY_SCAN, inventory=parent) +class InventorySingleFactView(MongoAPIView): + + model = Fact + parent_model = Inventory + new_in_220 = True + serializer_class = FactSerializer + filter_backends = (MongoFilterBackend,) + + def get(self, request, *args, **kwargs): + fact_key = request.QUERY_PARAMS.get("fact_key", None) + fact_value = request.QUERY_PARAMS.get("fact_value", None) + datetime_spec = request.QUERY_PARAMS.get("timestamp", None) + module_spec = request.QUERY_PARAMS.get("module", None) + + if fact_key is None or fact_value is None or module_spec is None: + return Response({"error": "Missing fields"}, status=status.HTTP_400_BAD_REQUEST) + datetime_actual = dateutil.parser.parse(datetime_spec) if datetime_spec is not None else now() + inventory_obj = self.get_parent_object() + fact_data = Fact.get_single_facts([h.name for h in inventory_obj.hosts.all()], fact_key, fact_value, datetime_actual, module_spec) + return Response(dict(results=FactSerializer(fact_data).data if fact_data is not None else [])) + class HostList(ListCreateAPIView): @@ -986,6 +1009,83 @@ class HostActivityStreamList(SubListAPIView): qs = self.request.user.get_queryset(self.model) return qs.filter(Q(host=parent) | Q(inventory=parent.inventory)) +class HostFactVersionsList(MongoListAPIView): + + serializer_class = FactVersionSerializer + parent_model = Host + new_in_220 = True + filter_backends = (MongoFilterBackend,) + + def get_queryset(self): + from_spec = self.request.QUERY_PARAMS.get('from', None) + to_spec = self.request.QUERY_PARAMS.get('to', None) + module_spec = self.request.QUERY_PARAMS.get('module', None) + + host = self.get_parent_object() + self.check_parent_access(host) + + try: + fact_host = FactHost.objects.get(hostname=host.name) + except FactHost.DoesNotExist: + return None + + kv = { + 'host': fact_host.id, + } + if module_spec is not None: + kv['module'] = module_spec + if from_spec is not None: + from_actual = dateutil.parser.parse(from_spec) + kv['timestamp__gt'] = from_actual + if to_spec is not None: + to_actual = dateutil.parser.parse(to_spec) + kv['timestamp__lte'] = to_actual + + return FactVersion.objects.filter(**kv).order_by("-timestamp") + + def list(self, *args, **kwargs): + queryset = self.get_queryset() or [] + serializer = FactVersionSerializer(queryset, many=True, context=dict(host_obj=self.get_parent_object())) + return Response(dict(results=serializer.data)) + +class HostSingleFactView(MongoAPIView): + + model = Fact + parent_model = Host + new_in_220 = True + serializer_class = FactSerializer + filter_backends = (MongoFilterBackend,) + + def get(self, request, *args, **kwargs): + fact_key = request.QUERY_PARAMS.get("fact_key", None) + fact_value = request.QUERY_PARAMS.get("fact_value", None) + datetime_spec = request.QUERY_PARAMS.get("timestamp", None) + module_spec = request.QUERY_PARAMS.get("module", None) + + if fact_key is None or fact_value is None or module_spec is None: + return Response({"error": "Missing fields"}, status=status.HTTP_400_BAD_REQUEST) + datetime_actual = dateutil.parser.parse(datetime_spec) if datetime_spec is not None else now() + host_obj = self.get_parent_object() + fact_data = Fact.get_single_facts([host_obj.name], fact_key, fact_value, datetime_actual, module_spec) + return Response(dict(results=FactSerializer(fact_data).data if fact_data is not None else [])) + +class HostFactCompareView(MongoAPIView): + + new_in_220 = True + parent_model = Host + serializer_class = FactSerializer + filter_backends = (MongoFilterBackend,) + + def get(self, request, *args, **kwargs): + datetime_spec = request.QUERY_PARAMS.get('datetime', None) + module_spec = request.QUERY_PARAMS.get('module', "ansible") + datetime_actual = dateutil.parser.parse(datetime_spec) if datetime_spec is not None else now() + + host_obj = self.get_parent_object() + fact_entry = Fact.get_host_version(host_obj.name, datetime_actual, module_spec) + host_data = FactSerializer(fact_entry).data if fact_entry is not None else {} + + return Response(host_data) class GroupList(ListCreateAPIView): @@ -1125,6 +1225,28 @@ class GroupDetail(RetrieveUpdateDestroyAPIView): obj.mark_inactive_recursive() return Response(status=status.HTTP_204_NO_CONTENT) + +class GroupSingleFactView(MongoAPIView): + + model = Fact + parent_model = Group + new_in_220 = True + serializer_class = FactSerializer + filter_backends = (MongoFilterBackend,) + + def get(self, request, *args, **kwargs): + fact_key = request.QUERY_PARAMS.get("fact_key", None) + fact_value = request.QUERY_PARAMS.get("fact_value", None) + datetime_spec = request.QUERY_PARAMS.get("timestamp", None) + module_spec = request.QUERY_PARAMS.get("module", None) + + if fact_key is None or fact_value is None or module_spec is None: + return Response({"error": "Missing fields"}, status=status.HTTP_400_BAD_REQUEST) + datetime_actual = dateutil.parser.parse(datetime_spec) if datetime_spec is not None else now() + group_obj = self.get_parent_object() + fact_data = Fact.get_single_facts([h.name for h in group_obj.hosts.all()], fact_key, fact_value, datetime_actual, module_spec) + return Response(dict(results=FactSerializer(fact_data).data if fact_data is not None else [])) + class InventoryGroupsList(SubListCreateAttachDetachAPIView): model = Group diff --git a/awx/fact/__init__.py b/awx/fact/__init__.py index f9d5796ca2..cc9b260832 100644 --- a/awx/fact/__init__.py +++ b/awx/fact/__init__.py @@ -14,7 +14,7 @@ logger = logging.getLogger('awx.fact') # Connect to Mongo try: - connect(settings.MONGO_DB) + connect(settings.MONGO_DB, tz_aware=settings.USE_TZ) register_key_transform(get_db()) except ConnectionError: logger.warn('Failed to establish connect to MongoDB "%s"' % (settings.MONGO_DB)) diff --git a/awx/fact/models/fact.py b/awx/fact/models/fact.py index e3705ee493..e9a7e30cf7 100644 --- a/awx/fact/models/fact.py +++ b/awx/fact/models/fact.py @@ -78,7 +78,7 @@ class Fact(Document): } try: - facts = Fact.objects.filter(**kv) + facts = Fact.objects.filter(**kv).order_by("-timestamp") if not facts: return None return facts[0] @@ -97,41 +97,85 @@ class Fact(Document): 'module': module, } - return FactVersion.objects.filter(**kv).values_list('timestamp') + return FactVersion.objects.filter(**kv).order_by("-timestamp").values_list('timestamp') @staticmethod - def get_single_facts(hostnames, fact_key, timestamp, module): - host_ids = FactHost.objects.filter(hostname__in=hostnames).values_list('id') - if not host_ids or len(host_ids) == 0: - return None - + def get_single_facts(hostnames, fact_key, fact_value, timestamp, module): kv = { - 'host__in': host_ids, - 'timestamp__lte': timestamp, - 'module': module, - } - facts = FactVersion.objects.filter(**kv).values_list('fact') - if not facts or len(facts) == 0: - return None - # TODO: Make sure the below doesn't trigger a query to get the fact record - # It's unclear as to if mongoengine will query the full fact when the id is referenced. - # This is not a logic problem, but a performance problem. - fact_ids = [fact.id for fact in facts] - - project = { - '$project': { - 'host': 1, - 'fact.%s' % fact_key: 1, + 'hostname': { + '$in': hostnames, } } - facts = Fact.objects.filter(id__in=fact_ids).aggregate(project) - return facts + fields = { + '_id': 1 + } + host_ids = FactHost._get_collection().find(kv, fields) + if not host_ids or host_ids.count() == 0: + return None + # TODO: use mongo to transform [{_id: <>}, {_id: <>},...] into [_id, _id,...] + host_ids = [e['_id'] for e in host_ids] + pipeline = [] + match = { + 'host': { + '$in': host_ids + }, + 'timestamp': { + '$lte': timestamp + }, + 'module': module + } + sort = { + 'timestamp': -1 + } + group = { + '_id': '$host', + 'timestamp': { + '$first': '$timestamp' + }, + 'fact': { + '$first': '$fact' + } + } + project = { + '_id': 0, + 'fact': 1, + } + pipeline.append({'$match': match}) # noqa + pipeline.append({'$sort': sort}) # noqa + pipeline.append({'$group': group}) # noqa + pipeline.append({'$project': project}) # noqa + q = FactVersion._get_collection().aggregate(pipeline) + if not q or 'result' not in q or len(q['result']) == 0: + return None + # TODO: use mongo to transform [{fact: <>}, {fact: <>},...] into [fact, fact,...] + fact_ids = [fact['fact'] for fact in q['result']] + + kv = { + 'fact.%s' % fact_key : fact_value, + '_id': { + '$in': fact_ids + } + } + fields = { + 'fact.%s.$' % fact_key : 1, + 'host': 1, + 'timestamp': 1, + 'module': 1, + } + facts = Fact._get_collection().find(kv, fields) + #fact_objs = [Fact(**f) for f in facts] + # Translate pymongo python structure to mongoengine Fact object + fact_objs = [] + for f in facts: + f['id'] = f.pop('_id') + fact_objs.append(Fact(**f)) + return fact_objs class FactVersion(Document): timestamp = DateTimeField(required=True) host = ReferenceField(FactHost, required=True) - module = StringField(max_length=50, required=True) + module = StringField(max_length=50, required=True) fact = ReferenceField(Fact, required=True) # TODO: Consider using hashed index on module. django-mongo may not support this but # executing raw js will @@ -141,4 +185,3 @@ class FactVersion(Document): 'module' ] } - \ No newline at end of file diff --git a/awx/fact/tests/__init__.py b/awx/fact/tests/__init__.py index d7187f3928..d276b01707 100644 --- a/awx/fact/tests/__init__.py +++ b/awx/fact/tests/__init__.py @@ -5,3 +5,4 @@ from __future__ import absolute_import from .models import * # noqa from .utils import * # noqa +from .base import * # noqa diff --git a/awx/fact/tests/base.py b/awx/fact/tests/base.py new file mode 100644 index 0000000000..48a708acd6 --- /dev/null +++ b/awx/fact/tests/base.py @@ -0,0 +1,200 @@ +# Copyright (c) 2015 Ansible, Inc. +# All Rights Reserved + +# Python +from __future__ import absolute_import +from django.utils.timezone import now + +# Django +from django.conf import settings +import django + +# MongoEngine +from mongoengine.connection import get_db, ConnectionError + +# AWX +from awx.fact.models.fact import * # noqa + +TEST_FACT_ANSIBLE = { + "ansible_swapfree_mb" : 4092, + "ansible_default_ipv6" : { + + }, + "ansible_distribution_release" : "trusty", + "ansible_system_vendor" : "innotek GmbH", + "ansible_os_family" : "Debian", + "ansible_all_ipv4_addresses" : [ + "192.168.1.145" + ], + "ansible_lsb" : { + "release" : "14.04", + "major_release" : "14", + "codename" : "trusty", + "id" : "Ubuntu", + "description" : "Ubuntu 14.04.2 LTS" + }, +} + +TEST_FACT_PACKAGES = [ + { + "name": "accountsservice", + "architecture": "amd64", + "source": "apt", + "version": "0.6.35-0ubuntu7.1" + }, + { + "name": "acpid", + "architecture": "amd64", + "source": "apt", + "version": "1:2.0.21-1ubuntu2" + }, + { + "name": "adduser", + "architecture": "all", + "source": "apt", + "version": "3.113+nmu3ubuntu3" + }, +] + +TEST_FACT_SERVICES = [ + { + "source" : "upstart", + "state" : "waiting", + "name" : "ureadahead-other", + "goal" : "stop" + }, + { + "source" : "upstart", + "state" : "running", + "name" : "apport", + "goal" : "start" + }, + { + "source" : "upstart", + "state" : "waiting", + "name" : "console-setup", + "goal" : "stop" + }, +] + + +class MongoDBRequired(django.test.TestCase): + def setUp(self): + # Drop mongo database + try: + self.db = get_db() + self.db.connection.drop_database(settings.MONGO_DB) + except ConnectionError: + self.skipTest('MongoDB connection failed') + +class BaseFactTestMixin(MongoDBRequired): + pass + +class BaseFactTest(BaseFactTestMixin, MongoDBRequired): + pass + +class FactScanBuilder(object): + + def __init__(self): + self.facts_data = {} + self.hostname_data = [] + + self.host_objs = [] + self.fact_objs = [] + self.version_objs = [] + self.timestamps = [] + + def add_fact(self, module, facts): + self.facts_data[module] = facts + + def add_hostname(self, hostname): + self.hostname_data.append(hostname) + + def build(self, scan_count, host_count): + if len(self.facts_data) == 0: + raise RuntimeError("No fact data to build populate scans. call add_fact()") + if (len(self.hostname_data) > 0 and len(self.hostname_data) != host_count): + raise RuntimeError("Registered number of hostnames %d does not match host_count %d" % (len(self.hostname_data), host_count)) + + if len(self.hostname_data) == 0: + self.hostname_data = ['hostname_%s' % i for i in range(0, host_count)] + + self.host_objs = [FactHost(hostname=hostname).save() for hostname in self.hostname_data] + + for i in range(0, scan_count): + scan = {} + scan_version = {} + timestamp = now().replace(year=2015 - i, microsecond=0) + for module in self.facts_data: + fact_objs = [] + version_objs = [] + for host in self.host_objs: + (fact_obj, version_obj) = Fact.add_fact(timestamp=timestamp, + host=host, + module=module, + fact=self.facts_data[module]) + fact_objs.append(fact_obj) + version_objs.append(version_obj) + scan[module] = fact_objs + scan_version[module] = version_objs + self.fact_objs.append(scan) + self.version_objs.append(scan_version) + self.timestamps.append(timestamp) + + + def get_scan(self, index, module=None): + res = None + res = self.fact_objs[index] + if module: + res = res[module] + return res + + def get_scans(self, index_start=None, index_end=None): + if index_start is None: + index_start = 0 + if index_end is None: + index_end = len(self.fact_objs) + return self.fact_objs[index_start:index_end] + + def get_scan_version(self, index, module=None): + res = None + res = self.version_objs[index] + if module: + res = res[module] + return res + + def get_scan_versions(self, index_start=None, index_end=None): + if index_start is None: + index_start = 0 + if index_end is None: + index_end = len(self.version_objs) + return self.version_objs[index_start:index_end] + + def get_hostname(self, index): + return self.host_objs[index].hostname + + def get_hostnames(self, index_start=None, index_end=None): + if index_start is None: + index_start = 0 + if index_end is None: + index_end = len(self.host_objs) + + return [self.host_objs[i].hostname for i in range(index_start, index_end)] + + + def get_scan_count(self): + return len(self.fact_objs) + + def get_host_count(self): + return len(self.host_objs) + + def get_timestamp(self, index): + return self.timestamps[index] + + def get_timestamps(self, index_start=None, index_end=None): + if not index_start: + index_start = 0 + if not index_end: + len(self.timestamps) + return self.timestamps[index_start:index_end] + diff --git a/awx/fact/tests/models/fact/base.py b/awx/fact/tests/models/fact/base.py deleted file mode 100644 index 3d8c4653f0..0000000000 --- a/awx/fact/tests/models/fact/base.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) 2015 Ansible, Inc. -# All Rights Reserved - -# Python -from __future__ import absolute_import -from awx.main.tests.base import BaseTest, MongoDBRequired - -# AWX -from awx.fact.models.fact import * # noqa - -''' -Helper functions (i.e. create_host_document) expect the structure: -{ - 'hostname': 'hostname1', - 'add_fact_data': { - 'timestamp': datetime.now(), - 'host': None, - 'module': 'packages', - 'fact': ... - } -} -''' -class BaseFactTest(BaseTest, MongoDBRequired): - - @staticmethod - def _normalize_timestamp(timestamp): - return timestamp.replace(microsecond=0) - - @staticmethod - def normalize_timestamp(data): - data['add_fact_data']['timestamp'] = BaseFactTest._normalize_timestamp(data['add_fact_data']['timestamp']) - - def create_host_document(self, data): - data['add_fact_data']['host'] = FactHost(hostname=data['hostname']).save() diff --git a/awx/fact/tests/models/fact/fact_get_single_facts.py b/awx/fact/tests/models/fact/fact_get_single_facts.py index 9329f0d653..ef1d4befa9 100644 --- a/awx/fact/tests/models/fact/fact_get_single_facts.py +++ b/awx/fact/tests/models/fact/fact_get_single_facts.py @@ -3,83 +3,37 @@ # Python from __future__ import absolute_import -from datetime import datetime -from copy import deepcopy # Django # AWX from awx.fact.models.fact import * # noqa -from .base import BaseFactTest - -__all__ = ['FactGetSingleFactsTest'] - -TEST_FACT_DATA = { - 'hostname': 'hostname_%d', - 'add_fact_data': { - 'timestamp': datetime.now(), - 'host': None, - 'module': 'packages', - 'fact': { - "accountsservice": [ - { - "architecture": "amd64", - "name": "accountsservice", - "source": "apt", - "version": "0.6.35-0ubuntu7.1" - } - ], - "acpid": [ - { - "architecture": "amd64", - "name": "acpid", - "source": "apt", - "version": "1:2.0.21-1ubuntu2" - } - ], - "adduser": [ - { - "architecture": "all", - "name": "adduser", - "source": "apt", - "version": "3.113+nmu3ubuntu3" - } - ], - }, - } -} +from awx.fact.tests.base import BaseFactTest, FactScanBuilder, TEST_FACT_PACKAGES +__all__ = ['FactGetSingleFactsTest', 'FactGetSingleFactsMultipleScansTest',] class FactGetSingleFactsTest(BaseFactTest): - def create_fact_scans_unique_hosts(self, host_count): - self.fact_data = [] - self.fact_objs = [] - self.hostnames = [] - for i in range(1, host_count + 1): - fact_data = deepcopy(TEST_FACT_DATA) - fact_data['hostname'] = fact_data['hostname'] % (i) - fact_data['add_fact_data']['timestamp'] = datetime.now().replace(year=2015 - i) - BaseFactTest.normalize_timestamp(fact_data) - - self.create_host_document(fact_data) - (fact_obj, version_obj) = Fact.add_fact(**fact_data['add_fact_data']) - - self.fact_data.append(fact_data) - self.fact_objs.append(fact_obj) - self.hostnames.append(fact_data['hostname']) - def setUp(self): super(FactGetSingleFactsTest, self).setUp() - self.host_count = 20 - self.create_fact_scans_unique_hosts(self.host_count) + self.builder = FactScanBuilder() + self.builder.add_fact('packages', TEST_FACT_PACKAGES) + self.builder.add_fact('nested', TEST_FACT_PACKAGES) + self.builder.build(scan_count=1, host_count=20) def check_query_results(self, facts_known, facts): - # Transpose facts to a dict with key _id + self.assertIsNotNone(facts) + self.assertEqual(len(facts_known), len(facts), "More or less facts found than expected") + # Ensure only 'acpid' is returned + for fact in facts: + self.assertEqual(len(fact.fact), 1) + self.assertEqual(fact.fact[0]['name'], 'acpid') + + # Transpose facts to a dict with key id count = 0 facts_dict = {} for fact in facts: count += 1 - facts_dict[fact['_id']] = fact + facts_dict[fact.id] = fact self.assertEqual(count, len(facts_known)) # For each fact that we put into the database on setup, @@ -87,20 +41,56 @@ class FactGetSingleFactsTest(BaseFactTest): for fact_known in facts_known: key = fact_known.id self.assertIn(key, facts_dict) - self.assertEqual(facts_dict[key]['fact']['acpid'], fact_known.fact['acpid']) - self.assertEqual(facts_dict[key]['host'], fact_known.host.id) + self.assertEqual(len(facts_dict[key].fact), 1) - def test_get_single_facts_ok(self): - timestamp = datetime.now().replace(year=2016) - facts = Fact.get_single_facts(self.hostnames, 'acpid', timestamp, 'packages') + def check_query_results_nested(self, facts): self.assertIsNotNone(facts) + for fact in facts: + self.assertEqual(len(fact.fact), 1) + self.assertEqual(fact.fact['nested'][0]['name'], 'acpid') - self.check_query_results(self.fact_objs, facts) + def test_single_host(self): + facts = Fact.get_single_facts(self.builder.get_hostnames(0, 1), 'name', 'acpid', self.builder.get_timestamp(0), 'packages') - def test_get_single_facts_subset_by_timestamp(self): - timestamp = datetime.now().replace(year=2010) - facts = Fact.get_single_facts(self.hostnames, 'acpid', timestamp, 'packages') - self.assertIsNotNone(facts) + self.check_query_results(self.builder.get_scan(0, 'packages')[:1], facts) + + def test_all(self): + facts = Fact.get_single_facts(self.builder.get_hostnames(), 'name', 'acpid', self.builder.get_timestamp(0), 'packages') + + self.check_query_results(self.builder.get_scan(0, 'packages'), facts) + + def test_subset_hosts(self): + host_count = (self.builder.get_host_count() / 2) + facts = Fact.get_single_facts(self.builder.get_hostnames(0, host_count), 'name', 'acpid', self.builder.get_timestamp(0), 'packages') + + self.check_query_results(self.builder.get_scan(0, 'packages')[:host_count], facts) + + def test_get_single_facts_nested(self): + facts = Fact.get_single_facts(self.builder.get_hostnames(), 'nested.name', 'acpid', self.builder.get_timestamp(0), 'packages') + + self.check_query_results_nested(facts) + +class FactGetSingleFactsMultipleScansTest(BaseFactTest): + def setUp(self): + super(FactGetSingleFactsMultipleScansTest, self).setUp() + self.builder = FactScanBuilder() + self.builder.add_fact('packages', TEST_FACT_PACKAGES) + self.builder.build(scan_count=10, host_count=10) + + def test_1_host(self): + facts = Fact.get_single_facts(self.builder.get_hostnames(0, 1), 'name', 'acpid', self.builder.get_timestamp(0), 'packages') + self.assertEqual(len(facts), 1) + self.assertEqual(facts[0], self.builder.get_scan(0, 'packages')[0]) + + def test_multiple_hosts(self): + facts = Fact.get_single_facts(self.builder.get_hostnames(0, 3), 'name', 'acpid', self.builder.get_timestamp(0), 'packages') + self.assertEqual(len(facts), 3) + for i, fact in enumerate(facts): + self.assertEqual(fact, self.builder.get_scan(0, 'packages')[i]) + + def test_middle_of_timeline(self): + facts = Fact.get_single_facts(self.builder.get_hostnames(0, 3), 'name', 'acpid', self.builder.get_timestamp(4), 'packages') + self.assertEqual(len(facts), 3) + for i, fact in enumerate(facts): + self.assertEqual(fact, self.builder.get_scan(4, 'packages')[i]) - self.check_query_results(self.fact_objs[4:], facts) - \ No newline at end of file diff --git a/awx/fact/tests/models/fact/fact_simple.py b/awx/fact/tests/models/fact/fact_simple.py index 9103416c22..587fbf736e 100644 --- a/awx/fact/tests/models/fact/fact_simple.py +++ b/awx/fact/tests/models/fact/fact_simple.py @@ -3,78 +3,29 @@ # Python from __future__ import absolute_import -from datetime import datetime -from copy import deepcopy +from django.utils.timezone import now +from dateutil.relativedelta import relativedelta # Django # AWX from awx.fact.models.fact import * # noqa -from .base import BaseFactTest +from awx.fact.tests.base import BaseFactTest, FactScanBuilder, TEST_FACT_PACKAGES __all__ = ['FactHostTest', 'FactTest', 'FactGetHostVersionTest', 'FactGetHostTimelineTest'] -TEST_FACT_DATA = { - 'hostname': 'hostname1', - 'add_fact_data': { - 'timestamp': datetime.now(), - 'host': None, - 'module': 'packages', - 'fact': { - "accountsservice": [ - { - "architecture": "amd64", - "name": "accountsservice", - "source": "apt", - "version": "0.6.35-0ubuntu7.1" - } - ], - "acpid": [ - { - "architecture": "amd64", - "name": "acpid", - "source": "apt", - "version": "1:2.0.21-1ubuntu2" - } - ], - "adduser": [ - { - "architecture": "all", - "name": "adduser", - "source": "apt", - "version": "3.113+nmu3ubuntu3" - } - ], - }, - } -} -# Strip off microseconds because mongo has less precision -BaseFactTest.normalize_timestamp(TEST_FACT_DATA) - -def create_fact_scans(count=1): - timestamps = [] - for i in range(0, count): - data = deepcopy(TEST_FACT_DATA) - t = datetime.now().replace(year=2015 - i, microsecond=0) - data['add_fact_data']['timestamp'] = t - (f, v) = Fact.add_fact(**data['add_fact_data']) - timestamps.append(t) - - return timestamps - - class FactHostTest(BaseFactTest): def test_create_host(self): - host = FactHost(hostname=TEST_FACT_DATA['hostname']) + host = FactHost(hostname='hosty') host.save() - host = FactHost.objects.get(hostname=TEST_FACT_DATA['hostname']) + host = FactHost.objects.get(hostname='hosty') self.assertIsNotNone(host, "Host added but not found") - self.assertEqual(TEST_FACT_DATA['hostname'], host.hostname, "Gotten record hostname does not match expected hostname") + self.assertEqual('hosty', host.hostname, "Gotten record hostname does not match expected hostname") # Ensure an error is raised for .get() that doesn't match a record. def test_get_host_id_no_result(self): - host = FactHost(hostname=TEST_FACT_DATA['hostname']) + host = FactHost(hostname='hosty') host.save() self.assertRaises(FactHost.DoesNotExist, FactHost.objects.get, hostname='doesnotexist') @@ -82,70 +33,64 @@ class FactHostTest(BaseFactTest): class FactTest(BaseFactTest): def setUp(self): super(FactTest, self).setUp() - self.create_host_document(TEST_FACT_DATA) def test_add_fact(self): - (f_obj, v_obj) = Fact.add_fact(**TEST_FACT_DATA['add_fact_data']) + timestamp = now().replace(microsecond=0) + host = FactHost(hostname="hosty").save() + (f_obj, v_obj) = Fact.add_fact(host=host, timestamp=timestamp, module='packages', fact=TEST_FACT_PACKAGES) f = Fact.objects.get(id=f_obj.id) v = FactVersion.objects.get(id=v_obj.id) self.assertEqual(f.id, f_obj.id) - self.assertEqual(f.module, TEST_FACT_DATA['add_fact_data']['module']) - self.assertEqual(f.fact, TEST_FACT_DATA['add_fact_data']['fact']) - self.assertEqual(f.timestamp, TEST_FACT_DATA['add_fact_data']['timestamp']) + self.assertEqual(f.module, 'packages') + self.assertEqual(f.fact, TEST_FACT_PACKAGES) + self.assertEqual(f.timestamp, timestamp) # host relationship created - self.assertEqual(f.host.id, TEST_FACT_DATA['add_fact_data']['host'].id) + self.assertEqual(f.host.id, host.id) # version created and related self.assertEqual(v.id, v_obj.id) - self.assertEqual(v.timestamp, TEST_FACT_DATA['add_fact_data']['timestamp']) - self.assertEqual(v.host.id, TEST_FACT_DATA['add_fact_data']['host'].id) + self.assertEqual(v.timestamp, timestamp) + self.assertEqual(v.host.id, host.id) self.assertEqual(v.fact.id, f_obj.id) - self.assertEqual(v.fact.module, TEST_FACT_DATA['add_fact_data']['module']) + self.assertEqual(v.fact.module, 'packages') class FactGetHostVersionTest(BaseFactTest): def setUp(self): super(FactGetHostVersionTest, self).setUp() - self.create_host_document(TEST_FACT_DATA) - - self.t1 = datetime.now().replace(second=1, microsecond=0) - self.t2 = datetime.now().replace(second=2, microsecond=0) - data = deepcopy(TEST_FACT_DATA) - data['add_fact_data']['timestamp'] = self.t1 - (self.f1, self.v1) = Fact.add_fact(**data['add_fact_data']) - data = deepcopy(TEST_FACT_DATA) - data['add_fact_data']['timestamp'] = self.t2 - (self.f2, self.v2) = Fact.add_fact(**data['add_fact_data']) + self.builder = FactScanBuilder() + self.builder.add_fact('packages', TEST_FACT_PACKAGES) + self.builder.build(scan_count=2, host_count=1) def test_get_host_version_exact_timestamp(self): - fact = Fact.get_host_version(hostname=TEST_FACT_DATA['hostname'], timestamp=self.t1, module=TEST_FACT_DATA['add_fact_data']['module']) - self.assertIsNotNone(fact, "Set of Facts not found") - self.assertEqual(self.f1.id, fact.id) - self.assertEqual(self.f1.fact, fact.fact) + fact_known = self.builder.get_scan(0, 'packages')[0] + fact = Fact.get_host_version(hostname=self.builder.get_hostname(0), timestamp=self.builder.get_timestamp(0), module='packages') + self.assertIsNotNone(fact) + self.assertEqual(fact_known, fact) def test_get_host_version_lte_timestamp(self): - t3 = datetime.now().replace(second=3, microsecond=0) - fact = Fact.get_host_version(hostname=TEST_FACT_DATA['hostname'], timestamp=t3, module=TEST_FACT_DATA['add_fact_data']['module']) - self.assertEqual(self.f1.id, fact.id) - self.assertEqual(self.f1.fact, fact.fact) + timestamp = self.builder.get_timestamp(0) + relativedelta(days=1) + fact_known = self.builder.get_scan(0, 'packages')[0] + fact = Fact.get_host_version(hostname=self.builder.get_hostname(0), timestamp=timestamp, module='packages') + self.assertIsNotNone(fact) + self.assertEqual(fact_known, fact) def test_get_host_version_none(self): - t3 = deepcopy(self.t1).replace(second=0) - fact = Fact.get_host_version(hostname=TEST_FACT_DATA['hostname'], timestamp=t3, module=TEST_FACT_DATA['add_fact_data']['module']) + timestamp = self.builder.get_timestamp(0) - relativedelta(years=20) + fact = Fact.get_host_version(hostname=self.builder.get_hostname(0), timestamp=timestamp, module='packages') self.assertIsNone(fact) class FactGetHostTimelineTest(BaseFactTest): def setUp(self): super(FactGetHostTimelineTest, self).setUp() - self.create_host_document(TEST_FACT_DATA) - - self.scans = 20 - self.timestamps = create_fact_scans(self.scans) + self.builder = FactScanBuilder() + self.builder.add_fact('packages', TEST_FACT_PACKAGES) + self.builder.build(scan_count=20, host_count=1) def test_get_host_timeline_ok(self): - timestamps = Fact.get_host_timeline(hostname=TEST_FACT_DATA['hostname'], module=TEST_FACT_DATA['add_fact_data']['module']) + timestamps = Fact.get_host_timeline(hostname=self.builder.get_hostname(0), module='packages') self.assertIsNotNone(timestamps) - self.assertEqual(len(timestamps), len(self.timestamps)) - for i in range(0, self.scans): - self.assertEqual(timestamps[i], self.timestamps[i]) + self.assertEqual(len(timestamps), self.builder.get_scan_count()) + for i in range(0, self.builder.get_scan_count()): + self.assertEqual(timestamps[i], self.builder.get_timestamp(i)) diff --git a/awx/fact/tests/models/fact/fact_transform.py b/awx/fact/tests/models/fact/fact_transform.py index 6661f81179..82c27fae9c 100644 --- a/awx/fact/tests/models/fact/fact_transform.py +++ b/awx/fact/tests/models/fact/fact_transform.py @@ -13,38 +13,45 @@ import pymongo # AWX from awx.fact.models.fact import * # noqa -from .base import BaseFactTest +from awx.fact.tests.base import BaseFactTest __all__ = ['FactTransformTest', 'FactTransformUpdateTest',] -TEST_FACT_DATA = { - 'hostname': 'hostname1', - 'add_fact_data': { - 'timestamp': datetime.now(), - 'host': None, - 'module': 'packages', - 'fact': { - "acpid3.4": [ - { - "version": "1:2.0.21-1ubuntu2", - "deeper.key": "some_value" - } - ], - "adduser.2": [ - { - "source": "apt", - "version": "3.113+nmu3ubuntu3" - } - ], - "what.ever." : { - "shallowish.key": "some_shallow_value" - } - }, +TEST_FACT_PACKAGES_WITH_DOTS = [ + { + "name": "acpid3.4", + "version": "1:2.0.21-1ubuntu2", + "deeper.key": "some_value" + }, + { + "name": "adduser.2", + "source": "apt", + "version": "3.113+nmu3ubuntu3" + }, + { + "what.ever." : { + "shallowish.key": "some_shallow_value" + } } -} -# Strip off microseconds because mongo has less precision -BaseFactTest.normalize_timestamp(TEST_FACT_DATA) +] +TEST_FACT_PACKAGES_WITH_DOLLARS = [ + { + "name": "acpid3$4", + "version": "1:2.0.21-1ubuntu2", + "deeper.key": "some_value" + }, + { + "name": "adduser$2", + "source": "apt", + "version": "3.113+nmu3ubuntu3" + }, + { + "what.ever." : { + "shallowish.key": "some_shallow_value" + } + } +] class FactTransformTest(BaseFactTest): def setUp(self): super(FactTransformTest, self).setUp() @@ -52,16 +59,16 @@ class FactTransformTest(BaseFactTest): self.client = pymongo.MongoClient('localhost', 27017) self.db2 = self.client[settings.MONGO_DB] - self.create_host_document(TEST_FACT_DATA) + self.timestamp = datetime.now().replace(microsecond=0) def setup_create_fact_dot(self): - self.data = TEST_FACT_DATA - self.f = Fact(**TEST_FACT_DATA['add_fact_data']) + self.host = FactHost(hostname='hosty').save() + self.f = Fact(timestamp=self.timestamp, module='packages', fact=TEST_FACT_PACKAGES_WITH_DOTS, host=self.host) self.f.save() def setup_create_fact_dollar(self): - self.data = TEST_FACT_DATA - self.f = Fact(**TEST_FACT_DATA['add_fact_data']) + self.host = FactHost(hostname='hosty').save() + self.f = Fact(timestamp=self.timestamp, module='packages', fact=TEST_FACT_PACKAGES_WITH_DOLLARS, host=self.host) self.f.save() def test_fact_with_dot_serialized(self): @@ -73,17 +80,18 @@ class FactTransformTest(BaseFactTest): # Bypass mongoengine and pymongo transform to get record f_dict = self.db2['fact'].find_one(q) - self.assertIn('acpid3\uff0E4', f_dict['fact']) + self.assertIn('what\uff0Eever\uff0E', f_dict['fact'][2]) def test_fact_with_dot_serialized_pymongo(self): #self.setup_create_fact_dot() + host = FactHost(hostname='hosty').save() f = self.db['fact'].insert({ - 'hostname': TEST_FACT_DATA['hostname'], - 'fact': TEST_FACT_DATA['add_fact_data']['fact'], - 'timestamp': TEST_FACT_DATA['add_fact_data']['timestamp'], - 'host': TEST_FACT_DATA['add_fact_data']['host'].id, - 'module': TEST_FACT_DATA['add_fact_data']['module'] + 'hostname': 'hosty', + 'fact': TEST_FACT_PACKAGES_WITH_DOTS, + 'timestamp': self.timestamp, + 'host': host.id, + 'module': 'packages', }) q = { @@ -91,7 +99,7 @@ class FactTransformTest(BaseFactTest): } # Bypass mongoengine and pymongo transform to get record f_dict = self.db2['fact'].find_one(q) - self.assertIn('acpid3\uff0E4', f_dict['fact']) + self.assertIn('what\uff0Eever\uff0E', f_dict['fact'][2]) def test_fact_with_dot_deserialized_pymongo(self): self.setup_create_fact_dot() @@ -100,13 +108,13 @@ class FactTransformTest(BaseFactTest): '_id': self.f.id } f_dict = self.db['fact'].find_one(q) - self.assertIn('acpid3.4', f_dict['fact']) + self.assertIn('what.ever.', f_dict['fact'][2]) def test_fact_with_dot_deserialized(self): self.setup_create_fact_dot() f = Fact.objects.get(id=self.f.id) - self.assertIn('acpid3.4', f.fact) + self.assertIn('what.ever.', f.fact[2]) class FactTransformUpdateTest(BaseFactTest): pass diff --git a/awx/fact/tests/models/fact/fact_transform_pymongo.py b/awx/fact/tests/models/fact/fact_transform_pymongo.py index 7cf81e4650..ac7c329980 100644 --- a/awx/fact/tests/models/fact/fact_transform_pymongo.py +++ b/awx/fact/tests/models/fact/fact_transform_pymongo.py @@ -13,7 +13,7 @@ import pymongo # AWX from awx.fact.models.fact import * # noqa -from .base import BaseFactTest +from awx.fact.tests.base import BaseFactTest __all__ = ['FactSerializePymongoTest', 'FactDeserializePymongoTest',] diff --git a/awx/lib/site-packages/README b/awx/lib/site-packages/README index c50c8f14d9..cdd9729bf9 100644 --- a/awx/lib/site-packages/README +++ b/awx/lib/site-packages/README @@ -38,6 +38,8 @@ keystoneclient==1.3.0 (keystone/*) kombu==3.0.21 (kombu/*) Markdown==2.4.1 (markdown/*, excluded bin/markdown_py) mock==1.0.1 (mock.py) +mongoengine==0.9.0 (mongoengine/*) +mongoengine_rest_framework==1.5.4 (rest_framework_mongoengine/*) netaddr==0.7.14 (netaddr/*) os_client_config==0.6.0 (os_client_config/*) ordereddict==1.1 (ordereddict.py, needed for Python 2.6 support) diff --git a/awx/lib/site-packages/rest_framework_mongoengine/__init__.py b/awx/lib/site-packages/rest_framework_mongoengine/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/awx/lib/site-packages/rest_framework_mongoengine/fields.py b/awx/lib/site-packages/rest_framework_mongoengine/fields.py new file mode 100644 index 0000000000..d18fe07156 --- /dev/null +++ b/awx/lib/site-packages/rest_framework_mongoengine/fields.py @@ -0,0 +1,137 @@ +from bson.errors import InvalidId +from django.core.exceptions import ValidationError +from django.utils.encoding import smart_str +from mongoengine import dereference +from mongoengine.base.document import BaseDocument +from mongoengine.document import Document +from rest_framework import serializers +from mongoengine.fields import ObjectId +import bson + + +class MongoDocumentField(serializers.WritableField): + MAX_RECURSION_DEPTH = 5 # default value of depth + + def __init__(self, *args, **kwargs): + try: + self.model_field = kwargs.pop('model_field') + self.depth = kwargs.pop('depth', self.MAX_RECURSION_DEPTH) + except KeyError: + raise ValueError("%s requires 'model_field' kwarg" % self.type_label) + + super(MongoDocumentField, self).__init__(*args, **kwargs) + + def transform_document(self, document, depth): + data = {} + + # serialize each required field + for field in document._fields: + if hasattr(document, smart_str(field)): + # finally check for an attribute 'field' on the instance + obj = getattr(document, field) + else: + continue + + val = self.transform_object(obj, depth-1) + + if val is not None: + data[field] = val + + return data + + def transform_dict(self, obj, depth): + return dict([(key, self.transform_object(val, depth-1)) + for key, val in obj.items()]) + + def transform_object(self, obj, depth): + """ + Models to natives + Recursion for (embedded) objects + """ + if isinstance(obj, BaseDocument): + # Document, EmbeddedDocument + if depth == 0: + # Return primary key if exists, else return default text + return smart_str(getattr(obj, 'pk', 'Max recursion depth exceeded')) + return self.transform_document(obj, depth) + elif isinstance(obj, dict): + # Dictionaries + return self.transform_dict(obj, depth) + elif isinstance(obj, list): + # List + return [self.transform_object(value, depth) for value in obj] + elif obj is None: + return None + else: + return smart_str(obj) if isinstance(obj, ObjectId) else obj + + +class ReferenceField(MongoDocumentField): + + type_label = 'ReferenceField' + + def from_native(self, value): + try: + dbref = self.model_field.to_python(value) + except InvalidId: + raise ValidationError(self.error_messages['invalid']) + + instance = dereference.DeReference().__call__([dbref])[0] + + # Check if dereference was successful + if not isinstance(instance, Document): + msg = self.error_messages['invalid'] + raise ValidationError(msg) + + return instance + + def to_native(self, obj): + #if type is DBRef it means Mongo can't find the actual reference object + #prevent the JSON serializable error by setting the object to None + if type(obj) == bson.dbref.DBRef: + obj = None + return self.transform_object(obj, self.depth - 1) + + +class ListField(MongoDocumentField): + + type_label = 'ListField' + + def from_native(self, value): + return self.model_field.to_python(value) + + def to_native(self, obj): + return self.transform_object(obj, self.depth - 1) + + +class EmbeddedDocumentField(MongoDocumentField): + + type_label = 'EmbeddedDocumentField' + + def __init__(self, *args, **kwargs): + try: + self.document_type = kwargs.pop('document_type') + except KeyError: + raise ValueError("EmbeddedDocumentField requires 'document_type' kwarg") + + super(EmbeddedDocumentField, self).__init__(*args, **kwargs) + + def get_default_value(self): + return self.to_native(self.default()) + + def to_native(self, obj): + if obj is None: + return None + else: + return self.transform_object(obj, self.depth) + + def from_native(self, value): + return self.model_field.to_python(value) + + +class DynamicField(MongoDocumentField): + + type_label = 'DynamicField' + + def to_native(self, obj): + return self.model_field.to_python(obj) diff --git a/awx/lib/site-packages/rest_framework_mongoengine/generics.py b/awx/lib/site-packages/rest_framework_mongoengine/generics.py new file mode 100644 index 0000000000..679a99d255 --- /dev/null +++ b/awx/lib/site-packages/rest_framework_mongoengine/generics.py @@ -0,0 +1,150 @@ +from django.core.exceptions import ImproperlyConfigured +from rest_framework import mixins +from rest_framework.generics import GenericAPIView +from mongoengine.django.shortcuts import get_document_or_404 + + +class MongoAPIView(GenericAPIView): + """ + Mixin for views manipulating mongo documents + + """ + queryset = None + serializer_class = None + lookup_field = 'id' + + def get_queryset(self): + """ + Get the list of items for this view. + This must be an iterable, and may be a queryset. + Defaults to using `self.queryset`. + + You may want to override this if you need to provide different + querysets depending on the incoming request. + + (Eg. return a list of items that is specific to the user) + """ + if self.queryset is not None: + return self.queryset.clone() + + if self.model is not None: + return self.get_serializer().opts.model.objects.all() + + raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" + % self.__class__.__name__) + + def get_object(self, queryset=None): + """ + Get a document instance for read/update/delete requests. + """ + query_key = self.lookup_url_kwarg or self.lookup_field + query_kwargs = {query_key: self.kwargs[query_key]} + queryset = self.get_queryset() + + obj = get_document_or_404(queryset, **query_kwargs) + self.check_object_permissions(self.request, obj) + + return obj + + +class CreateAPIView(mixins.CreateModelMixin, + MongoAPIView): + + """ + Concrete view for creating a model instance. + """ + def post(self, request, *args, **kwargs): + return self.create(request, *args, **kwargs) + + +class ListAPIView(mixins.ListModelMixin, + MongoAPIView): + """ + Concrete view for listing a queryset. + """ + def get(self, request, *args, **kwargs): + return self.list(request, *args, **kwargs) + + +class ListCreateAPIView(mixins.ListModelMixin, + mixins.CreateModelMixin, + MongoAPIView): + """ + Concrete view for listing a queryset or creating a model instance. + """ + def get(self, request, *args, **kwargs): + return self.list(request, *args, **kwargs) + + def post(self, request, *args, **kwargs): + return self.create(request, *args, **kwargs) + + +class RetrieveAPIView(mixins.RetrieveModelMixin, + MongoAPIView): + """ + Concrete view for retrieving a model instance. + """ + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + +class UpdateAPIView(mixins.UpdateModelMixin, + MongoAPIView): + + """ + Concrete view for updating a model instance. + """ + def put(self, request, *args, **kwargs): + return self.update(request, *args, **kwargs) + + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + +class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + MongoAPIView): + """ + Concrete view for retrieving, updating a model instance. + """ + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + def put(self, request, *args, **kwargs): + return self.update(request, *args, **kwargs) + + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + +class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, + mixins.DestroyModelMixin, + MongoAPIView): + """ + Concrete view for retrieving or deleting a model instance. + """ + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) + + +class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + mixins.DestroyModelMixin, + MongoAPIView): + """ + Concrete view for retrieving, updating or deleting a model instance. + """ + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + def put(self, request, *args, **kwargs): + return self.update(request, *args, **kwargs) + + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) \ No newline at end of file diff --git a/awx/lib/site-packages/rest_framework_mongoengine/routers.py b/awx/lib/site-packages/rest_framework_mongoengine/routers.py new file mode 100644 index 0000000000..af12281340 --- /dev/null +++ b/awx/lib/site-packages/rest_framework_mongoengine/routers.py @@ -0,0 +1,22 @@ +from rest_framework.routers import SimpleRouter, DefaultRouter + + +class MongoRouterMixin(object): + def get_default_base_name(self, viewset): + """ + If `base_name` is not specified, attempt to automatically determine + it from the viewset. + """ + model_cls = getattr(viewset, 'model', None) + assert model_cls, '`base_name` argument not specified, and could ' \ + 'not automatically determine the name from the viewset, as ' \ + 'it does not have a `.model` attribute.' + return model_cls.__name__.lower() + + +class MongoSimpleRouter(MongoRouterMixin, SimpleRouter): + pass + + +class MongoDefaultRouter(MongoSimpleRouter, DefaultRouter): + pass \ No newline at end of file diff --git a/awx/lib/site-packages/rest_framework_mongoengine/serializers.py b/awx/lib/site-packages/rest_framework_mongoengine/serializers.py new file mode 100644 index 0000000000..ed427d8d5a --- /dev/null +++ b/awx/lib/site-packages/rest_framework_mongoengine/serializers.py @@ -0,0 +1,268 @@ +from __future__ import unicode_literals +import warnings +from mongoengine.errors import ValidationError +from rest_framework import serializers +from rest_framework import fields +import mongoengine +from mongoengine.base import BaseDocument +from django.core.paginator import Page +from django.db import models +from django.forms import widgets +from django.utils.datastructures import SortedDict +from rest_framework.compat import get_concrete_model +from .fields import ReferenceField, ListField, EmbeddedDocumentField, DynamicField + + +class MongoEngineModelSerializerOptions(serializers.ModelSerializerOptions): + """ + Meta class options for MongoEngineModelSerializer + """ + def __init__(self, meta): + super(MongoEngineModelSerializerOptions, self).__init__(meta) + self.depth = getattr(meta, 'depth', 5) + + +class MongoEngineModelSerializer(serializers.ModelSerializer): + """ + Model Serializer that supports Mongoengine + """ + _options_class = MongoEngineModelSerializerOptions + + def perform_validation(self, attrs): + """ + Rest Framework built-in validation + related model validations + """ + for field_name, field in self.fields.items(): + if field_name in self._errors: + continue + + source = field.source or field_name + if self.partial and source not in attrs: + continue + + if field_name in attrs and hasattr(field, 'model_field'): + try: + field.model_field.validate(attrs[field_name]) + except ValidationError as err: + self._errors[field_name] = str(err) + + try: + validate_method = getattr(self, 'validate_%s' % field_name, None) + if validate_method: + attrs = validate_method(attrs, source) + except serializers.ValidationError as err: + self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) + + if not self._errors: + try: + attrs = self.validate(attrs) + except serializers.ValidationError as err: + if hasattr(err, 'message_dict'): + for field_name, error_messages in err.message_dict.items(): + self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages) + elif hasattr(err, 'messages'): + self._errors['non_field_errors'] = err.messages + + return attrs + + def restore_object(self, attrs, instance=None): + if instance is None: + instance = self.opts.model() + + dynamic_fields = self.get_dynamic_fields(instance) + all_fields = dict(dynamic_fields, **self.fields) + + for key, val in attrs.items(): + field = all_fields.get(key) + if not field or field.read_only: + continue + + if isinstance(field, serializers.Serializer): + many = field.many + + def _restore(field, item): + # looks like a bug, sometimes there are decerialized objects in attrs + # sometimes they are just dicts + if isinstance(item, BaseDocument): + return item + return field.from_native(item) + + if many: + val = [_restore(field, item) for item in val] + else: + val = _restore(field, val) + + key = getattr(field, 'source', None) or key + try: + setattr(instance, key, val) + except ValueError: + self._errors[key] = self.error_messages['required'] + + return instance + + def get_default_fields(self): + cls = self.opts.model + opts = get_concrete_model(cls) + fields = [] + fields += [getattr(opts, field) for field in cls._fields_ordered] + + ret = SortedDict() + + for model_field in fields: + if isinstance(model_field, mongoengine.ObjectIdField): + field = self.get_pk_field(model_field) + else: + field = self.get_field(model_field) + + if field: + field.initialize(parent=self, field_name=model_field.name) + ret[model_field.name] = field + + for field_name in self.opts.read_only_fields: + assert field_name in ret,\ + "read_only_fields on '%s' included invalid item '%s'" %\ + (self.__class__.__name__, field_name) + ret[field_name].read_only = True + + for field_name in self.opts.write_only_fields: + assert field_name in ret,\ + "write_only_fields on '%s' included invalid item '%s'" %\ + (self.__class__.__name__, field_name) + ret[field_name].write_only = True + + return ret + + def get_dynamic_fields(self, obj): + dynamic_fields = {} + if obj is not None and obj._dynamic: + for key, value in obj._dynamic_fields.items(): + dynamic_fields[key] = self.get_field(value) + return dynamic_fields + + def get_field(self, model_field): + kwargs = {} + + if model_field.__class__ in (mongoengine.ReferenceField, mongoengine.EmbeddedDocumentField, + mongoengine.ListField, mongoengine.DynamicField): + kwargs['model_field'] = model_field + kwargs['depth'] = self.opts.depth + + if not model_field.__class__ == mongoengine.ObjectIdField: + kwargs['required'] = model_field.required + + if model_field.__class__ == mongoengine.EmbeddedDocumentField: + kwargs['document_type'] = model_field.document_type + + if model_field.default: + kwargs['required'] = False + kwargs['default'] = model_field.default + + if model_field.__class__ == models.TextField: + kwargs['widget'] = widgets.Textarea + + field_mapping = { + mongoengine.FloatField: fields.FloatField, + mongoengine.IntField: fields.IntegerField, + mongoengine.DateTimeField: fields.DateTimeField, + mongoengine.EmailField: fields.EmailField, + mongoengine.URLField: fields.URLField, + mongoengine.StringField: fields.CharField, + mongoengine.BooleanField: fields.BooleanField, + mongoengine.FileField: fields.FileField, + mongoengine.ImageField: fields.ImageField, + mongoengine.ObjectIdField: fields.WritableField, + mongoengine.ReferenceField: ReferenceField, + mongoengine.ListField: ListField, + mongoengine.EmbeddedDocumentField: EmbeddedDocumentField, + mongoengine.DynamicField: DynamicField, + mongoengine.DecimalField: fields.DecimalField, + mongoengine.UUIDField: fields.CharField + } + + attribute_dict = { + mongoengine.StringField: ['max_length'], + mongoengine.DecimalField: ['min_value', 'max_value'], + mongoengine.EmailField: ['max_length'], + mongoengine.FileField: ['max_length'], + mongoengine.URLField: ['max_length'], + } + + if model_field.__class__ in attribute_dict: + attributes = attribute_dict[model_field.__class__] + for attribute in attributes: + kwargs.update({attribute: getattr(model_field, attribute)}) + + try: + return field_mapping[model_field.__class__](**kwargs) + except KeyError: + # Defaults to WritableField if not in field mapping + return fields.WritableField(**kwargs) + + def to_native(self, obj): + """ + Rest framework built-in to_native + transform_object + """ + ret = self._dict_class() + ret.fields = self._dict_class() + + #Dynamic Document Support + dynamic_fields = self.get_dynamic_fields(obj) + all_fields = self._dict_class() + all_fields.update(self.fields) + all_fields.update(dynamic_fields) + + for field_name, field in all_fields.items(): + if field.read_only and obj is None: + continue + field.initialize(parent=self, field_name=field_name) + key = self.get_field_key(field_name) + value = field.field_to_native(obj, field_name) + #Override value with transform_ methods + method = getattr(self, 'transform_%s' % field_name, None) + if callable(method): + value = method(obj, value) + if not getattr(field, 'write_only', False): + ret[key] = value + ret.fields[key] = self.augment_field(field, field_name, key, value) + + return ret + + def from_native(self, data, files=None): + self._errors = {} + + if data is not None or files is not None: + attrs = self.restore_fields(data, files) + for key in data.keys(): + if key not in attrs: + attrs[key] = data[key] + if attrs is not None: + attrs = self.perform_validation(attrs) + else: + self._errors['non_field_errors'] = ['No input provided'] + + if not self._errors: + return self.restore_object(attrs, instance=getattr(self, 'object', None)) + + @property + def data(self): + """ + Returns the serialized data on the serializer. + """ + if self._data is None: + obj = self.object + + if self.many is not None: + many = self.many + else: + many = hasattr(obj, '__iter__') and not isinstance(obj, (BaseDocument, Page, dict)) + if many: + warnings.warn('Implicit list/queryset serialization is deprecated. ' + 'Use the `many=True` flag when instantiating the serializer.', + DeprecationWarning, stacklevel=2) + + if many: + self._data = [self.to_native(item) for item in obj] + else: + self._data = self.to_native(obj) + + return self._data diff --git a/awx/lib/site-packages/rest_framework_mongoengine/tests/__init__.py b/awx/lib/site-packages/rest_framework_mongoengine/tests/__init__.py new file mode 100644 index 0000000000..0ebb5972d8 --- /dev/null +++ b/awx/lib/site-packages/rest_framework_mongoengine/tests/__init__.py @@ -0,0 +1,6 @@ +import os +import sys + +sys.path.insert(0, os.path.abspath('../')) +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "Sample.settings") + diff --git a/awx/lib/site-packages/rest_framework_mongoengine/tests/requirements.txt b/awx/lib/site-packages/rest_framework_mongoengine/tests/requirements.txt new file mode 100644 index 0000000000..8b764864de --- /dev/null +++ b/awx/lib/site-packages/rest_framework_mongoengine/tests/requirements.txt @@ -0,0 +1,7 @@ +Django==1.6.5 +argparse==1.2.1 +djangorestframework==2.3.14 +mongoengine==0.8.7 +nose==1.3.3 +pymongo==2.7.1 +wsgiref==0.1.2 diff --git a/awx/lib/site-packages/rest_framework_mongoengine/tests/test_serializers.py b/awx/lib/site-packages/rest_framework_mongoengine/tests/test_serializers.py new file mode 100644 index 0000000000..0d73ef7c75 --- /dev/null +++ b/awx/lib/site-packages/rest_framework_mongoengine/tests/test_serializers.py @@ -0,0 +1,152 @@ +from datetime import datetime +import mongoengine as me +from unittest import TestCase +from bson import objectid + +from rest_framework_mongoengine.serializers import MongoEngineModelSerializer +from rest_framework import serializers as s + + +class Job(me.Document): + title = me.StringField() + status = me.StringField(choices=('draft', 'published')) + notes = me.StringField(required=False) + on = me.DateTimeField(default=datetime.utcnow) + weight = me.IntField(default=0) + + +class JobSerializer(MongoEngineModelSerializer): + id = s.Field() + title = s.CharField() + status = s.ChoiceField(read_only=True) + sort_weight = s.IntegerField(source='weight') + + + class Meta: + model = Job + fields = ('id', 'title','status', 'sort_weight') + + + +class TestReadonlyRestore(TestCase): + + def test_restore_object(self): + job = Job(title='original title', status='draft', notes='secure') + data = { + 'title': 'updated title ...', + 'status': 'published', # this one is read only + 'notes': 'hacked', # this field should not update + 'sort_weight': 10 # mapped to a field with differet name + } + + serializer = JobSerializer(job, data=data, partial=True) + + self.assertTrue(serializer.is_valid()) + obj = serializer.object + self.assertEqual(data['title'], obj.title) + self.assertEqual('draft', obj.status) + self.assertEqual('secure', obj.notes) + + self.assertEqual(10, obj.weight) + + + + + +# Testing restoring embedded property + +class Location(me.EmbeddedDocument): + city = me.StringField() + +# list of +class Category(me.EmbeddedDocument): + id = me.StringField() + counter = me.IntField(default=0, required=True) + + +class Secret(me.EmbeddedDocument): + key = me.StringField() + +class SomeObject(me.Document): + name = me.StringField() + loc = me.EmbeddedDocumentField('Location') + categories = me.ListField(me.EmbeddedDocumentField(Category)) + codes = me.ListField(me.EmbeddedDocumentField(Secret)) + + +class LocationSerializer(MongoEngineModelSerializer): + city = s.CharField() + + class Meta: + model = Location + +class CategorySerializer(MongoEngineModelSerializer): + id = s.CharField(max_length=24) + class Meta: + model = Category + fields = ('id',) + +class SomeObjectSerializer(MongoEngineModelSerializer): + location = LocationSerializer(source='loc') + categories = CategorySerializer(many=True, allow_add_remove=True) + + class Meta: + model = SomeObject + fields = ('name', 'location', 'categories') + + +class TestRestoreEmbedded(TestCase): + def setUp(self): + self.data = { + 'name': 'some anme', + 'location': { + 'city': 'Toronto' + }, + 'categories': [{'id': 'cat1'}, {'id': 'category_2', 'counter': 666}], + 'codes': [{'key': 'mykey1'}] + } + + def test_restore_new(self): + serializer = SomeObjectSerializer(data=self.data) + self.assertTrue(serializer.is_valid()) + obj = serializer.object + + self.assertEqual(self.data['name'], obj.name ) + self.assertEqual('Toronto', obj.loc.city ) + + self.assertEqual(2, len(obj.categories)) + self.assertEqual('category_2', obj.categories[1].id) + # counter is not listed in serializer fields, cannot be updated + self.assertEqual(0, obj.categories[1].counter) + + # codes are not listed, should not be updatable + self.assertEqual(0, len(obj.codes)) + + def test_restore_update(self): + data = self.data + instance = SomeObject( + name='original', + loc=Location(city="New York"), + categories=[Category(id='orig1', counter=777)], + codes=[Secret(key='confidential123')] + ) + serializer = SomeObjectSerializer(instance, data=data, partial=True) + + # self.assertTrue(serializer.is_valid()) + if not serializer.is_valid(): + print 'errors: %s' % serializer._errors + assert False, 'errors' + + obj = serializer.object + + self.assertEqual(data['name'], obj.name ) + self.assertEqual('Toronto', obj.loc.city ) + + # codes is not listed, should not be updatable + self.assertEqual(1, len(obj.codes[0])) + self.assertEqual('confidential123', obj.codes[0].key) # should keep original val + + self.assertEqual(2, len(obj.categories)) + self.assertEqual('category_2', obj.categories[1].id) + self.assertEqual(0, obj.categories[1].counter) + diff --git a/awx/lib/site-packages/rest_framework_mongoengine/viewsets.py b/awx/lib/site-packages/rest_framework_mongoengine/viewsets.py new file mode 100644 index 0000000000..730dfec959 --- /dev/null +++ b/awx/lib/site-packages/rest_framework_mongoengine/viewsets.py @@ -0,0 +1,34 @@ +from rest_framework import mixins +from rest_framework.viewsets import ViewSetMixin +from rest_framework_mongoengine.generics import MongoAPIView + + +class MongoGenericViewSet(ViewSetMixin, MongoAPIView): + """ + The MongoGenericViewSet class does not provide any actions by default, + but does include the base set of generic view behavior, such as + the `get_object` and `get_queryset` methods. + """ + pass + + +class ModelViewSet(mixins.CreateModelMixin, + mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + mixins.DestroyModelMixin, + mixins.ListModelMixin, + MongoGenericViewSet): + """ + A viewset that provides default `create()`, `retrieve()`, `update()`, + `partial_update()`, `destroy()` and `list()` actions. + """ + pass + + +class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, + mixins.ListModelMixin, + MongoGenericViewSet): + """ + A viewset that provides default `list()` and `retrieve()` actions. + """ + pass \ No newline at end of file diff --git a/awx/main/tests/__init__.py b/awx/main/tests/__init__.py index dda898b544..3b5e1fcf12 100644 --- a/awx/main/tests/__init__.py +++ b/awx/main/tests/__init__.py @@ -16,3 +16,4 @@ from awx.main.tests.schedules import * # noqa from awx.main.tests.redact import * # noqa from awx.main.tests.views import * # noqa from awx.main.tests.commands import * # noqa +from awx.main.tests.fact import * # noqa diff --git a/awx/main/tests/base.py b/awx/main/tests/base.py index 354158eff5..bb38e21d79 100644 --- a/awx/main/tests/base.py +++ b/awx/main/tests/base.py @@ -25,9 +25,6 @@ from django.contrib.auth.models import User from django.test.client import Client from django.test.utils import override_settings -# MongoEngine -from mongoengine.connection import get_db, ConnectionError - # AWX from awx.main.models import * # noqa from awx.main.backend import LDAPSettings @@ -43,15 +40,6 @@ TEST_PLAYBOOK = '''- hosts: mygroup command: test 1 = 1 ''' -class MongoDBRequired(django.test.TestCase): - def setUp(self): - # Drop mongo database - try: - self.db = get_db() - self.db.connection.drop_database(settings.MONGO_DB) - except ConnectionError: - self.skipTest('MongoDB connection failed') - class QueueTestMixin(object): def start_queue(self): self.start_redis() diff --git a/awx/main/tests/commands/cleanup_facts.py b/awx/main/tests/commands/cleanup_facts.py index 8d02310a2d..cdddce9f8f 100644 --- a/awx/main/tests/commands/cleanup_facts.py +++ b/awx/main/tests/commands/cleanup_facts.py @@ -10,7 +10,8 @@ import mock from django.core.management.base import CommandError # AWX -from awx.main.tests.base import BaseTest, MongoDBRequired +from awx.main.tests.base import BaseTest +from awx.fact.tests.base import MongoDBRequired from awx.main.tests.commands.base import BaseCommandMixin from awx.main.management.commands.cleanup_facts import Command, CleanupFacts from awx.fact.models.fact import * # noqa diff --git a/awx/main/tests/commands/run_fact_cache_receiver.py b/awx/main/tests/commands/run_fact_cache_receiver.py index d0f2d17b49..13f150a9b5 100644 --- a/awx/main/tests/commands/run_fact_cache_receiver.py +++ b/awx/main/tests/commands/run_fact_cache_receiver.py @@ -10,7 +10,8 @@ from copy import deepcopy from mock import MagicMock # AWX -from awx.main.tests.base import BaseTest, MongoDBRequired +from awx.main.tests.base import BaseTest +from awx.fact.tests.base import MongoDBRequired from awx.main.tests.commands.base import BaseCommandMixin from awx.main.management.commands.run_fact_cache_receiver import FactCacheReceiver from awx.fact.models.fact import * # noqa diff --git a/awx/main/tests/fact/__init__.py b/awx/main/tests/fact/__init__.py new file mode 100644 index 0000000000..234499d6e9 --- /dev/null +++ b/awx/main/tests/fact/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2015 Ansible, Inc. +# All Rights Reserved + +from __future__ import absolute_import + +from .fact_api import * # noqa diff --git a/awx/main/tests/fact/fact_api.py b/awx/main/tests/fact/fact_api.py new file mode 100644 index 0000000000..55156d1436 --- /dev/null +++ b/awx/main/tests/fact/fact_api.py @@ -0,0 +1,236 @@ +# Copyright (c) 2015 Ansible, Inc. +# All Rights Reserved + +# Python + +# Django +from django.core.urlresolvers import reverse + +# AWX +from awx.main.utils import timestamp_apiformat +from awx.main.models import * # noqa +from awx.main.tests.base import BaseLiveServerTest +from awx.fact.models import * # noqa +from awx.fact.tests.base import BaseFactTestMixin, FactScanBuilder, TEST_FACT_ANSIBLE, TEST_FACT_PACKAGES, TEST_FACT_SERVICES +from awx.main.utils import build_url + +__all__ = ['FactVersionApiTest', 'FactViewApiTest', 'SingleFactApiTest',] + +class FactApiBaseTest(BaseLiveServerTest, BaseFactTestMixin): + def setUp(self): + super(FactApiBaseTest, self).setUp() + self.setup_instances() + self.setup_users() + self.organization = self.make_organization(self.super_django_user) + self.organization.admins.add(self.normal_django_user) + self.inventory = self.organization.inventories.create(name='test-inventory', description='description for test-inventory') + self.host = self.inventory.hosts.create(name='host.example.com') + self.host2 = self.inventory.hosts.create(name='host2.example.com') + self.host3 = self.inventory.hosts.create(name='host3.example.com') + + def setup_facts(self, scan_count): + self.builder = FactScanBuilder() + self.builder.add_fact('ansible', TEST_FACT_ANSIBLE) + self.builder.add_fact('packages', TEST_FACT_PACKAGES) + self.builder.add_fact('services', TEST_FACT_SERVICES) + self.builder.add_hostname('host.example.com') + self.builder.add_hostname('host2.example.com') + self.builder.add_hostname('host3.example.com') + self.builder.build(scan_count=scan_count, host_count=3) + + self.fact_host = FactHost.objects.get(hostname=self.host.name) + +class FactVersionApiTest(FactApiBaseTest): + def check_equal(self, fact_versions, results): + def find(element, set1): + for e in set1: + if all([ e.get(field) == element.get(field) for field in element.keys()]): + return e + return None + + self.assertEqual(len(results), len(fact_versions)) + for v in fact_versions: + v_dict = { + 'timestamp': timestamp_apiformat(v.timestamp), + 'module': v.module + } + e = find(v_dict, results) + self.assertIsNotNone(e, "%s not found in %s" % (v_dict, results)) + + def get_list(self, fact_versions, params=None): + url = build_url('api:host_fact_versions_list', args=(self.host.pk,), get=params) + with self.current_user(self.super_django_user): + response = self.get(url, expect=200) + + self.check_equal(fact_versions, response['results']) + return response + + def test_permission_list(self): + url = reverse('api:host_fact_versions_list', args=(self.host.pk,)) + with self.current_user('admin'): + self.get(url, expect=200) + with self.current_user('normal'): + self.get(url, expect=200) + with self.current_user('other'): + self.get(url, expect=403) + with self.current_user('nobody'): + self.get(url, expect=403) + with self.current_user(None): + self.get(url, expect=401) + + def test_list_empty(self): + url = reverse('api:host_fact_versions_list', args=(self.host.pk,)) + with self.current_user(self.super_django_user): + response = self.get(url, expect=200) + self.assertIn('results', response) + self.assertIsInstance(response['results'], list) + self.assertEqual(len(response['results']), 0) + + def test_list_related_fact_view(self): + self.setup_facts(2) + url = reverse('api:host_fact_versions_list', args=(self.host.pk,)) + with self.current_user(self.super_django_user): + response = self.get(url, expect=200) + for entry in response['results']: + self.assertIn('fact_view', entry['related']) + r = self.get(entry['related']['fact_view'], expect=200) + + def test_list(self): + self.setup_facts(2) + self.get_list(FactVersion.objects.filter(host=self.fact_host)) + + def test_list_module(self): + self.setup_facts(10) + self.get_list(FactVersion.objects.filter(host=self.fact_host, module='packages'), dict(module='packages')) + + def test_list_time_from(self): + self.setup_facts(10) + + params = { + 'from': timestamp_apiformat(self.builder.get_timestamp(1)), + } + # 'to': timestamp_apiformat(self.builder.get_timestamp(3)) + fact_versions = FactVersion.objects.filter(host=self.fact_host, timestamp__gt=params['from']) + self.get_list(fact_versions, params) + + def test_list_time_to(self): + self.setup_facts(10) + + params = { + 'to': timestamp_apiformat(self.builder.get_timestamp(3)) + } + fact_versions = FactVersion.objects.filter(host=self.fact_host, timestamp__lte=params['to']) + self.get_list(fact_versions, params) + + def test_list_time_from_to(self): + self.setup_facts(10) + + params = { + 'from': timestamp_apiformat(self.builder.get_timestamp(1)), + 'to': timestamp_apiformat(self.builder.get_timestamp(3)) + } + fact_versions = FactVersion.objects.filter(host=self.fact_host, timestamp__gt=params['from'], timestamp__lte=params['to']) + self.get_list(fact_versions, params) + + +class FactViewApiTest(FactApiBaseTest): + def check_equal(self, fact_obj, results): + fact_dict = { + 'timestamp': timestamp_apiformat(fact_obj.timestamp), + 'module': fact_obj.module, + 'host': { + 'hostname': fact_obj.host.hostname, + 'id': str(fact_obj.host.id) + }, + 'fact': fact_obj.fact + } + self.assertEqual(fact_dict, results) + + def test_permission_view(self): + url = reverse('api:host_fact_compare_view', args=(self.host.pk,)) + with self.current_user('admin'): + self.get(url, expect=200) + with self.current_user('normal'): + self.get(url, expect=200) + with self.current_user('other'): + self.get(url, expect=403) + with self.current_user('nobody'): + self.get(url, expect=403) + with self.current_user(None): + self.get(url, expect=401) + + def get_fact(self, fact_obj, params=None): + url = build_url('api:host_fact_compare_view', args=(self.host.pk,), get=params) + with self.current_user(self.super_django_user): + response = self.get(url, expect=200) + + self.check_equal(fact_obj, response) + + def test_view(self): + self.setup_facts(2) + self.get_fact(Fact.objects.filter(host=self.fact_host, module='ansible').order_by('-timestamp')[0]) + + def test_view_module_filter(self): + self.setup_facts(2) + self.get_fact(Fact.objects.filter(host=self.fact_host, module='services').order_by('-timestamp')[0], dict(module='services')) + + def test_view_time_filter(self): + self.setup_facts(6) + ts = self.builder.get_timestamp(3) + self.get_fact(Fact.objects.filter(host=self.fact_host, module='ansible', timestamp__lte=ts).order_by('-timestamp')[0], + dict(datetime=ts)) + +class SingleFactApiTest(FactApiBaseTest): + def setUp(self): + super(SingleFactApiTest, self).setUp() + + self.group = self.inventory.groups.create(name='test-group') + self.group.hosts.add(self.host, self.host2, self.host3) + + def test_permission_list(self): + url = reverse('api:host_fact_versions_list', args=(self.host.pk,)) + with self.current_user('admin'): + self.get(url, expect=200) + with self.current_user('normal'): + self.get(url, expect=200) + with self.current_user('other'): + self.get(url, expect=403) + with self.current_user('nobody'): + self.get(url, expect=403) + with self.current_user(None): + self.get(url, expect=401) + + def _test_related(self, url): + with self.current_user(self.super_django_user): + response = self.get(url, expect=200) + self.assertTrue(len(response['results']) > 0) + for entry in response['results']: + self.assertIn('single_fact', entry['related']) + # Requires fields + r = self.get(entry['related']['single_fact'], expect=400) + + def test_related_host_list(self): + self.setup_facts(2) + self._test_related(reverse('api:host_list')) + + def test_related_group_list(self): + self.setup_facts(2) + self._test_related(reverse('api:group_list')) + + def test_related_inventory_list(self): + self.setup_facts(2) + self._test_related(reverse('api:inventory_list')) + + def test_params(self): + self.setup_facts(2) + params = { + 'module': 'packages', + 'fact_key': 'name', + 'fact_value': 'acpid', + } + url = build_url('api:inventory_single_fact_view', args=(self.inventory.pk,), get=params) + with self.current_user(self.super_django_user): + response = self.get(url, expect=200) + self.assertEqual(len(response['results']), 3) + for entry in response['results']: + self.assertEqual(entry['fact'][0]['name'], 'acpid') diff --git a/awx/main/utils.py b/awx/main/utils.py index be8c44e890..f980722804 100644 --- a/awx/main/utils.py +++ b/awx/main/utils.py @@ -19,7 +19,7 @@ import tempfile # Django REST Framework from rest_framework.exceptions import ParseError, PermissionDenied from django.utils.encoding import smart_str - +from django.core.urlresolvers import reverse # PyCrypto from Crypto.Cipher import AES @@ -487,3 +487,16 @@ def get_pk_from_dict(_dict, key): return int(_dict[key]) except (TypeError, KeyError, ValueError): return None + +def build_url(*args, **kwargs): + get = kwargs.pop('get', {}) + url = reverse(*args, **kwargs) + if get: + url += '?' + urllib.urlencode(get) + return url + +def timestamp_apiformat(timestamp): + timestamp = timestamp.isoformat() + if timestamp.endswith('+00:00'): + timestamp = timestamp[:-6] + 'Z' + return timestamp diff --git a/awx/playbooks/scan_facts.yml b/awx/playbooks/scan_facts.yml index cd33098771..ffbe4512fd 100644 --- a/awx/playbooks/scan_facts.yml +++ b/awx/playbooks/scan_facts.yml @@ -4,6 +4,7 @@ scan_use_recursive: false tasks: - scan_packages: + - scan_services: - scan_files: path: '{{ scan_file_path }}' get_checksum: '{{ scan_use_checksum }}' diff --git a/awx/plugins/library/scan_packages.py b/awx/plugins/library/scan_packages.py index 8fb2a68008..8077cfe45d 100755 --- a/awx/plugins/library/scan_packages.py +++ b/awx/plugins/library/scan_packages.py @@ -3,10 +3,47 @@ import os from ansible.module_utils.basic import * # noqa +DOCUMENTATION = ''' +--- +module: scan_packages +short_description: Return installed packages information as fact data +description: + - Return information about installed packages as fact data +version_added: "1.9" +options: +requirements: [ ] +author: Matthew Jones +''' + +EXAMPLES = ''' +# Example fact output: +# host | success >> { +# "ansible_facts": { +# "services": [ +# { +# "source": "apt", +# "version": "1.0.6-5", +# "architecture": "amd64", +# "name": "libbz2-1.0" +# }, +# { +# "source": "apt", +# "version": "2.7.1-4ubuntu1", +# "architecture": "amd64", +# "name": "patch" +# }, +# { +# "source": "apt", +# "version": "4.8.2-19ubuntu1", +# "architecture": "amd64", +# "name": "gcc-4.8-base" +# }, ... ] } } +''' + def rpm_package_list(): import rpm trans_set = rpm.TransactionSet() - installed_packages = {} + installed_packages = [] for package in trans_set.dbMatch(): package_details = dict(name=package[rpm.RPMTAG_NAME], version=package[rpm.RPMTAG_VERSION], @@ -14,16 +51,13 @@ def rpm_package_list(): epoch=package[rpm.RPMTAG_EPOCH], arch=package[rpm.RPMTAG_ARCH], source='rpm') - if package['name'] not in installed_packages: - installed_packages[package['name']] = [package_details] - else: - installed_packages[package['name']].append(package_details) + installed_packages.append(package_details) return installed_packages def deb_package_list(): import apt apt_cache = apt.Cache() - installed_packages = {} + installed_packages = [] apt_installed_packages = [pk for pk in apt_cache.keys() if apt_cache[pk].is_installed] for package in apt_installed_packages: ac_pkg = apt_cache[package].installed @@ -31,7 +65,7 @@ def deb_package_list(): version=ac_pkg.version, architecture=ac_pkg.architecture, source='apt') - installed_packages[package] = [package_details] + installed_packages.append(package_details) return installed_packages def main():