pgreco / rpms / ipa

Forked from forks/areguera/rpms/ipa 4 years ago
Clone
Blob Blame History Raw
From be32fd1d727fc8398dd51fa0fd3f404ef451281a Mon Sep 17 00:00:00 2001
From: Jan Cholasta <jcholast@redhat.com>
Date: Mon, 1 Aug 2016 09:53:39 +0200
Subject: [PATCH] cert: speed up cert-find

Use issuer+serial rather than raw DER blob to identify certificates in
cert-find's intermediate result.

Restructure the code to make it (hopefully) easier to follow.

https://fedorahosted.org/freeipa/ticket/6098

Reviewed-By: Martin Basti <mbasti@redhat.com>
Reviewed-By: Pavel Vomacka <pvomacka@redhat.com>
---
 ipaserver/plugins/cert.py | 398 +++++++++++++++++++++++++---------------------
 1 file changed, 216 insertions(+), 182 deletions(-)

diff --git a/ipaserver/plugins/cert.py b/ipaserver/plugins/cert.py
index 06041d3083565e8d093b610473d6083111d406d2..47dccf15a4010f2766642aedd2cc16e0a1eb1dd4 100644
--- a/ipaserver/plugins/cert.py
+++ b/ipaserver/plugins/cert.py
@@ -21,6 +21,7 @@
 
 import base64
 import binascii
+import collections
 import datetime
 import os
 
@@ -295,18 +296,24 @@ class BaseCertObject(Object):
         ),
     )
 
-    def _parse(self, obj):
-        cert = x509.load_certificate(obj['certificate'])
-        obj['subject'] = DN(unicode(cert.subject))
-        obj['issuer'] = DN(unicode(cert.issuer))
-        obj['valid_not_before'] = unicode(cert.valid_not_before_str)
-        obj['valid_not_after'] = unicode(cert.valid_not_after_str)
-        obj['md5_fingerprint'] = unicode(
-            nss.data_to_hex(nss.md5_digest(cert.der_data), 64)[0])
-        obj['sha1_fingerprint'] = unicode(
-            nss.data_to_hex(nss.sha1_digest(cert.der_data), 64)[0])
-        obj['serial_number'] = cert.serial_number
-        obj['serial_number_hex'] = u'0x%X' % cert.serial_number
+    def _parse(self, obj, full=True):
+        cert = obj.get('certificate')
+        if cert is not None:
+            cert = x509.load_certificate(cert)
+            obj['subject'] = DN(unicode(cert.subject))
+            obj['issuer'] = DN(unicode(cert.issuer))
+            obj['serial_number'] = cert.serial_number
+            if full:
+                obj['valid_not_before'] = unicode(cert.valid_not_before_str)
+                obj['valid_not_after'] = unicode(cert.valid_not_after_str)
+                obj['md5_fingerprint'] = unicode(
+                    nss.data_to_hex(nss.md5_digest(cert.der_data), 64)[0])
+                obj['sha1_fingerprint'] = unicode(
+                    nss.data_to_hex(nss.sha1_digest(cert.der_data), 64)[0])
+
+        serial_number = obj.get('serial_number')
+        if serial_number is not None:
+            obj['serial_number_hex'] = u'0x%X' % serial_number
 
 
 class BaseCertMethod(Method):
@@ -691,10 +698,14 @@ class cert(BaseCertObject):
             yield self.api.Object[name]
 
     def _fill_owners(self, obj):
+        dns = obj.pop('owner', None)
+        if dns is None:
+            return
+
         for owner in self._owners():
             container_dn = DN(owner.container_dn, self.api.env.basedn)
             name = 'owner_' + owner.name
-            for dn in obj['owner']:
+            for dn in dns:
                 if dn.endswith(container_dn, 1):
                     value = owner.get_primary_key_from_dn(dn)
                     obj.setdefault(name, []).append(value)
@@ -776,9 +787,7 @@ class cert_show(Retrieve, CertMethod, VirtualCommand):
             result['certificate'] = result['certificate'].replace('\r\n', '')
             self.obj._parse(result)
             result['revoked'] = ('revocation_reason' in result)
-            if 'owner' in result:
-                self.obj._fill_owners(result)
-                del result['owner']
+            self.obj._fill_owners(result)
 
         if hostname:
             # If we have a hostname we want to verify that the subject
@@ -984,36 +993,171 @@ class cert_find(Search, CertMethod):
                 label=owner.object_name,
             )
 
-    def execute(self, criteria=None, all=False, raw=False, pkey_only=False,
-                no_members=True, timelimit=None, sizelimit=None, **options):
-        ca_options = {'cacn',
-                      'revocation_reason',
-                      'issuer',
-                      'subject',
-                      'min_serial_number', 'max_serial_number',
-                      'exactly',
-                      'validnotafter_from', 'validnotafter_to',
-                      'validnotbefore_from', 'validnotbefore_to',
-                      'issuedon_from', 'issuedon_to',
-                      'revokedon_from', 'revokedon_to'}
-        ldap_options = {prefix + owner.name
-                        for owner in self.obj._owners()
-                        for prefix in ('', 'no_')}
-        has_ca_options = (
-            any(name in options for name in ca_options - {'exactly'}) or
-            options['exactly'])
-        has_ldap_options = any(name in options for name in ldap_options)
-        has_cert_option = 'certificate' in options
+    def _get_cert_key(self, cert):
+        nss_cert = x509.load_certificate(cert, x509.DER)
+
+        return (DN(unicode(nss_cert.issuer)), nss_cert.serial_number)
+
+    def _get_cert_obj(self, cert, all, raw, pkey_only):
+        obj = {'certificate': unicode(base64.b64encode(cert))}
+
+        full = not pkey_only and all
+        if not raw:
+            self.obj._parse(obj, full)
+        if not full:
+            del obj['certificate']
+
+        return obj
+
+    def _cert_search(self, all, raw, pkey_only, **options):
+        result = collections.OrderedDict()
+
+        try:
+            cert = options['certificate']
+        except KeyError:
+            return result, False, False
+
+        key = self._get_cert_key(cert)
+
+        result[key] = self._get_cert_obj(cert, all, raw, pkey_only)
+
+        return result, False, True
+
+    def _ca_search(self, all, raw, pkey_only, sizelimit, exactly, **options):
+        ra_options = {}
+        for name in ('revocation_reason',
+                     'issuer',
+                     'subject',
+                     'min_serial_number', 'max_serial_number',
+                     'validnotafter_from', 'validnotafter_to',
+                     'validnotbefore_from', 'validnotbefore_to',
+                     'issuedon_from', 'issuedon_to',
+                     'revokedon_from', 'revokedon_to'):
+            try:
+                value = options[name]
+            except KeyError:
+                continue
+            if isinstance(value, datetime.datetime):
+                value = value.strftime(PKIDATE_FORMAT)
+            elif isinstance(value, DN):
+                value = unicode(value)
+            ra_options[name] = value
+        if sizelimit:
+            ra_options['sizelimit'] = sizelimit
+        if exactly:
+            ra_options['exactly'] = True
+
+        result = collections.OrderedDict()
+        complete = bool(ra_options)
 
         try:
             ca_enabled_check()
         except errors.NotFound:
-            if has_ca_options:
+            if ra_options:
                 raise
-            ca_enabled = False
+            return result, False, complete
+
+        ra = self.api.Backend.ra
+        for ra_obj in ra.find(ra_options):
+            issuer = DN(ra_obj['issuer'])
+            serial_number = ra_obj['serial_number']
+
+            if pkey_only:
+                obj = {'serial_number': serial_number}
+            else:
+                obj = ra_obj
+                obj['issuer'] = issuer
+                obj['subject'] = DN(ra_obj['subject'])
+                del obj['serial_number_hex']
+
+                if all:
+                    ra_obj = ra.get_certificate(str(serial_number))
+                    if not raw:
+                        obj['certificate'] = (
+                            ra_obj['certificate'].replace('\r\n', ''))
+                        self.obj._parse(obj)
+
+            result[issuer, serial_number] = obj
+
+        return result, False, complete
+
+    def _ldap_search(self, all, raw, pkey_only, no_members, timelimit,
+                     sizelimit, **options):
+        ldap = self.api.Backend.ldap2
+
+        filters = []
+        for owner in self.obj._owners():
+            for prefix, rule in (('', ldap.MATCH_ALL),
+                                 ('no_', ldap.MATCH_NONE)):
+                try:
+                    value = options[prefix + owner.name]
+                except KeyError:
+                    continue
+
+                filter = ldap.make_filter_from_attr(
+                    'objectclass',
+                    owner.object_class,
+                    ldap.MATCH_ALL)
+                if filter not in filters:
+                    filters.append(filter)
+
+                filter = ldap.make_filter_from_attr(
+                    owner.primary_key.name,
+                    value,
+                    rule)
+                filters.append(filter)
+
+        cert = options.get('certificate')
+        if cert is not None:
+            filter = ldap.make_filter_from_attr('usercertificate', cert)
+            filters.append(filter)
+
+        result = collections.OrderedDict()
+        complete = bool(filters)
+
+        if cert is None:
+            filter = '(usercertificate=*)'
+            filters.append(filter)
+
+        filter = ldap.combine_filters(filters, ldap.MATCH_ALL)
+        try:
+            entries, truncated = ldap.find_entries(
+                base_dn=self.api.env.basedn,
+                filter=filter,
+                attrs_list=['usercertificate'],
+                time_limit=timelimit,
+                size_limit=sizelimit,
+            )
+        except errors.EmptyResult:
+            entries = []
+            truncated = False
         else:
-            ca_enabled = True
+            truncated = bool(truncated)
+
+        for entry in entries:
+            for attr in ('usercertificate', 'usercertificate;binary'):
+                for cert in entry.get(attr, []):
+                    key = self._get_cert_key(cert)
+
+                    try:
+                        obj = result[key]
+                    except KeyError:
+                        obj = self._get_cert_obj(cert, all, raw, pkey_only)
+                        result[key] = obj
 
+                    if not pkey_only and (all or not no_members):
+                        owners = obj.setdefault('owner', [])
+                        if entry.dn not in owners:
+                            owners.append(entry.dn)
+
+        if not raw:
+            for obj in six.itervalues(result):
+                self.obj._fill_owners(obj)
+
+        return result, truncated, complete
+
+    def execute(self, criteria=None, all=False, raw=False, pkey_only=False,
+                no_members=True, timelimit=None, sizelimit=None, **options):
         if 'cacn' in options:
             ca_obj = api.Command.ca_show(options['cacn'])['result']
             ca_sdn = unicode(ca_obj['ipacasubjectdn'][0])
@@ -1028,153 +1172,43 @@ class cert_find(Search, CertMethod):
         if criteria is not None:
             return dict(result=[], count=0, truncated=False)
 
-        obj_seq = []
-        obj_dict = {}
+        result = collections.OrderedDict()
         truncated = False
-
-        if has_cert_option:
-            cert = options['certificate']
-            obj = {'certificate': unicode(base64.b64encode(cert))}
-            obj_seq.append(obj)
-            obj_dict[cert] = obj
-
-        if ca_enabled:
-            ra_options = {}
-            for name, value in options.items():
-                if name not in ca_options:
-                    continue
-                if isinstance(value, datetime.datetime):
-                    value = value.strftime(PKIDATE_FORMAT)
-                elif isinstance(value, DN):
-                    value = unicode(value)
-                ra_options[name] = value
-            if sizelimit is not None:
-                if sizelimit != 0:
-                    ra_options['sizelimit'] = sizelimit
-                sizelimit = 0
-                has_ca_options = True
-
-            for ra_obj in self.Backend.ra.find(ra_options):
-                obj = {}
-                if ((not pkey_only and all) or
-                        not no_members or
-                        not has_ca_options or
-                        has_ldap_options or
-                        has_cert_option):
-                    ra_obj.update(
-                        self.Backend.ra.get_certificate(
-                            str(ra_obj['serial_number'])))
-                    cert = base64.b64decode(ra_obj['certificate'])
-                    try:
-                        obj = obj_dict[cert]
-                    except KeyError:
-                        if has_cert_option:
-                            continue
-                        obj = {}
-                        obj_seq.append(obj)
-                        obj_dict[cert] = obj
+        complete = False
+
+        for sub_search in (self._cert_search,
+                           self._ca_search,
+                           self._ldap_search):
+            sub_result, sub_truncated, sub_complete = sub_search(
+                all=all,
+                raw=raw,
+                pkey_only=pkey_only,
+                no_members=no_members,
+                timelimit=timelimit,
+                sizelimit=sizelimit,
+                **options)
+
+            if sub_complete:
+                sizelimit = None
+
+                for key in tuple(result):
+                    if key not in sub_result:
+                        del result[key]
+
+            for key, sub_obj in six.iteritems(sub_result):
+                try:
+                    obj = result[key]
+                except KeyError:
+                    if complete:
+                        continue
+                    result[key] = sub_obj
                 else:
-                    obj_seq.append(obj)
-                obj.update(ra_obj)
-
-        if ((not pkey_only and all) or
-                not no_members or
-                not has_ca_options or
-                has_ldap_options or
-                has_cert_option):
-            ldap = self.api.Backend.ldap2
+                    obj.update(sub_obj)
 
-            filters = []
-            if 'certificate' in options:
-                cert_filter = ldap.make_filter_from_attr(
-                    'usercertificate', options['certificate'])
-            else:
-                cert_filter = '(usercertificate=*)'
-            filters.append(cert_filter)
-            for owner in self.obj._owners():
-                oc_filter = ldap.make_filter_from_attr(
-                    'objectclass', owner.object_class, ldap.MATCH_ALL)
-                for prefix, rule in (('', ldap.MATCH_ALL),
-                                     ('no_', ldap.MATCH_NONE)):
-                    value = options.get(prefix + owner.name)
-                    if value is None:
-                        continue
-                    pkey_filter = ldap.make_filter_from_attr(
-                        owner.primary_key.name, value, rule)
-                    filters.append(oc_filter)
-                    filters.append(pkey_filter)
-            filter = ldap.combine_filters(filters, ldap.MATCH_ALL)
+            truncated = truncated or sub_truncated
+            complete = complete or sub_complete
 
-            try:
-                entries, truncated = ldap.find_entries(
-                    base_dn=self.api.env.basedn,
-                    filter=filter,
-                    attrs_list=['usercertificate'],
-                    time_limit=timelimit,
-                    size_limit=sizelimit,
-                )
-            except errors.EmptyResult:
-                entries, truncated = [], False
-            for entry in entries:
-                seen = set()
-                for attr in ('usercertificate', 'usercertificate;binary'):
-                    for cert in entry.get(attr, []):
-                        if cert in seen:
-                            continue
-                        seen.add(cert)
-                        try:
-                            obj = obj_dict[cert]
-                        except KeyError:
-                            if has_ca_options or has_cert_option:
-                                continue
-                            obj = {
-                                'certificate': unicode(base64.b64encode(cert))}
-                            obj_seq.append(obj)
-                            obj_dict[cert] = obj
-                        obj.setdefault('owner', []).append(entry.dn)
-
-        result = []
-        for obj in obj_seq:
-            if has_ldap_options and 'owner' not in obj:
-                continue
-            if not pkey_only:
-                if not raw:
-                    if 'certificate' in obj:
-                        obj['certificate'] = (
-                            obj['certificate'].replace('\r\n', ''))
-                        self.obj._parse(obj)
-                        if not all:
-                            del obj['certificate']
-                            del obj['valid_not_before']
-                            del obj['valid_not_after']
-                            del obj['md5_fingerprint']
-                            del obj['sha1_fingerprint']
-                    if 'subject' in obj:
-                        obj['subject'] = DN(obj['subject'])
-                    if 'issuer' in obj:
-                        obj['issuer'] = DN(obj['issuer'])
-                    if 'status' in obj:
-                        obj['revoked'] = (
-                            obj['status'] in (u'REVOKED', u'REVOKED_EXPIRED'))
-                    if 'owner' in obj:
-                        if all or not no_members:
-                            self.obj._fill_owners(obj)
-                        del obj['owner']
-                else:
-                    if 'certificate' in obj:
-                        if not all:
-                            del obj['certificate']
-                    if 'owner' in obj:
-                        if not all and no_members:
-                            del obj['owner']
-            else:
-                if 'serial_number' in obj:
-                    serial_number = obj['serial_number']
-                    obj.clear()
-                    obj['serial_number'] = serial_number
-                else:
-                    obj.clear()
-            result.append(obj)
+        result = list(six.itervalues(result))
 
         ret = dict(
             result=result
-- 
2.7.4