pgreco / rpms / ipa

Forked from forks/areguera/rpms/ipa 4 years ago
Clone
Blob Blame History Raw
From beff42632d1db674802c817afd49a3ac8bcd8fb6 Mon Sep 17 00:00:00 2001
From: David Kupka <dkupka@redhat.com>
Date: Wed, 27 Jul 2016 10:46:40 +0200
Subject: [PATCH] schema: Speed up schema cache

Check presence of schema in cache (and download it if necessary) on
__init__ instead of with each __getitem__ call. Prefill internal
dictionary with empty record for each command to be able to quickly
determine if requested command exist in schema or not. Rest of schema
data are read from cache on first attempt to retrive them.

https://fedorahosted.org/freeipa/ticket/6048
https://fedorahosted.org/freeipa/ticket/6069

Reviewed-By: Jan Cholasta <jcholast@redhat.com>
---
 ipaclient/remote_plugins/schema.py | 301 ++++++++++++++++++++++---------------
 1 file changed, 177 insertions(+), 124 deletions(-)

diff --git a/ipaclient/remote_plugins/schema.py b/ipaclient/remote_plugins/schema.py
index 0301e54127dc236ebc14e1409484626f1427800d..d039fb41991c26a9c7b7f76f6959668efb677586 100644
--- a/ipaclient/remote_plugins/schema.py
+++ b/ipaclient/remote_plugins/schema.py
@@ -5,10 +5,8 @@
 import collections
 import errno
 import fcntl
-import glob
 import json
 import os
-import re
 import sys
 import time
 import types
@@ -65,8 +63,6 @@ USER_CACHE_PATH = (
         '.cache'
     )
 )
-SCHEMA_DIR = os.path.join(USER_CACHE_PATH, 'ipa', 'schema')
-SERVERS_DIR = os.path.join(USER_CACHE_PATH, 'ipa', 'servers')
 
 logger = log_mgr.get_logger(__name__)
 
@@ -274,15 +270,6 @@ class _SchemaObjectPlugin(_SchemaPlugin):
     schema_key = 'classes'
 
 
-def _ensure_dir_created(d):
-    try:
-        os.makedirs(d)
-    except OSError as e:
-        if e.errno != errno.EEXIST:
-            raise RuntimeError("Unable to create cache directory: {}"
-                               "".format(e))
-
-
 class _LockedZipFile(zipfile.ZipFile):
     """ Add locking to zipfile.ZipFile
     Shared lock is used with read mode, exclusive with write mode.
@@ -308,7 +295,10 @@ class _SchemaNameSpace(collections.Mapping):
         self._schema = schema
 
     def __getitem__(self, key):
-        return self._schema.read_namespace_member(self.name, key)
+        try:
+            return self._schema.read_namespace_member(self.name, key)
+        except KeyError:
+            raise KeyError(key)
 
     def __iter__(self):
         for key in self._schema.iter_namespace(self.name):
@@ -322,6 +312,62 @@ class NotAvailable(Exception):
     pass
 
 
+class ServerInfo(collections.MutableMapping):
+    _DIR = os.path.join(USER_CACHE_PATH, 'ipa', 'servers')
+
+    def __init__(self, api):
+        hostname = DNSName(api.env.server).ToASCII()
+        self._path = os.path.join(self._DIR, hostname)
+        self._dict = {}
+        self._dirty = False
+
+        self._read()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *_exc_info):
+        if self._dirty:
+            self._write()
+
+    def _read(self):
+        try:
+            with open(self._path, 'r') as sc:
+                self._dict = json.load(sc)
+        except EnvironmentError as e:
+            if e.errno != errno.ENOENT:
+                logger.warning('Failed to read server info: {}'.format(e))
+
+    def _write(self):
+        try:
+            try:
+                os.makedirs(self._DIR)
+            except EnvironmentError as e:
+                if e.errno != errno.EEXIST:
+                    raise
+            with open(self._path, 'w') as sc:
+                json.dump(self._dict, sc)
+        except EnvironmentError as e:
+            logger.warning('Failed to write server info: {}'.format(e))
+
+    def __getitem__(self, key):
+        return self._dict[key]
+
+    def __setitem__(self, key, value):
+        self._dirty = key not in self._dict or self._dict[key] != value
+        self._dict[key] = value
+
+    def __delitem__(self, key):
+        del self._dict[key]
+        self._dirty = True
+
+    def __iter__(self):
+        return iter(self._dict)
+
+    def __len__(self):
+        return len(self._dict)
+
+
 class Schema(object):
     """
     Store and provide schema for commands and topics
@@ -342,38 +388,76 @@ class Schema(object):
     u'Ping the remote IPA server to ...'
 
     """
-    schema_path_template = os.path.join(SCHEMA_DIR, '{}')
-    servers_path_template = os.path.join(SERVERS_DIR, '{}')
-    ns_member_pattern_template = '^{}/(?P<name>.+)$'
-    ns_member_path_template = '{}/{}'
     namespaces = {'classes', 'commands', 'topics'}
     schema_info_path = 'schema'
+    _DIR = os.path.join(USER_CACHE_PATH, 'ipa', 'schema')
 
-    @classmethod
-    def _list(cls):
-        for f in glob.glob(cls.schema_path_template.format('*')):
-            yield os.path.splitext(os.path.basename(f))[0]
+    def __init__(self, api, server_info, client):
+        self._dict = {}
+        self._namespaces = {}
+        self._help = None
 
-    @classmethod
-    def _in_cache(cls, fingeprint):
-        return os.path.exists(cls.schema_path_template.format(fingeprint))
+        for ns in self.namespaces:
+            self._dict[ns] = {}
+            self._namespaces[ns] = _SchemaNameSpace(self, ns)
 
-    def __init__(self, api, client):
-        self._api = api
-        self._client = client
-        self._dict = {}
+        is_known = False
+        if not api.env.force_schema_check:
+            try:
+                self._fingerprint = server_info['fingerprint']
+                self._expiration = server_info['expiration']
+            except KeyError:
+                pass
+            else:
+                is_known = True
+
+        if is_known:
+            try:
+                self._read_schema()
+            except Exception:
+                pass
+            else:
+                return
 
-    def _open_server_info(self, hostname, mode):
-        encoded_hostname = DNSName(hostname).ToASCII()
-        path = self.servers_path_template.format(encoded_hostname)
-        return open(path, mode)
+        try:
+            self._fetch(client)
+        except NotAvailable:
+            raise
+        else:
+            self._write_schema()
+        finally:
+            try:
+                server_info['fingerprint'] = self._fingerprint
+                server_info['expiration'] = self._expiration
+            except AttributeError:
+                pass
 
-    def _get_schema(self):
-        client = self._client
+    def _open_schema(self, filename, mode):
+        path = os.path.join(self._DIR, filename)
+        return _LockedZipFile(path, mode)
+
+    def _get_schema_fingerprint(self, schema):
+        schema_info = json.loads(schema.read(self.schema_info_path))
+        return schema_info['fingerprint']
+
+    def _fetch(self, client):
         if not client.isconnected():
             client.connect(verbose=False)
 
-        fps = [unicode(f) for f in Schema._list()]
+        fps = []
+        try:
+            files = os.listdir(self._DIR)
+        except EnvironmentError:
+            pass
+        else:
+            for filename in files:
+                try:
+                    with self._open_schema(filename, 'r') as schema:
+                        fps.append(
+                            unicode(self._get_schema_fingerprint(schema)))
+                except Exception:
+                    continue
+
         kwargs = {u'version': u'2.170'}
         if fps:
             kwargs[u'known_fingerprints'] = fps
@@ -386,110 +470,80 @@ class Schema(object):
             ttl = e.ttl
         else:
             fp = schema['fingerprint']
-            ttl = schema['ttl']
-            self._store(fp, schema)
-        finally:
-            client.disconnect()
+            ttl = schema.pop('ttl', 0)
 
-        exp = ttl + time.time()
-        return (fp, exp)
+            for key, value in schema.items():
+                if key in self.namespaces:
+                    value = {m['full_name']: m for m in value}
+                self._dict[key] = value
 
-    def _ensure_cached(self):
-        no_info = False
-        try:
-            # pylint: disable=access-member-before-definition
-            fp = self._server_schema_fingerprint
-            exp = self._server_schema_expiration
-        except AttributeError:
-            try:
-                with self._open_server_info(self._api.env.server, 'r') as sc:
-                    si = json.load(sc)
-
-                fp = si['fingerprint']
-                exp = si['expiration']
-            except Exception as e:
-                no_info = True
-                if not (isinstance(e, EnvironmentError) and
-                        e.errno == errno.ENOENT):  # pylint: disable=no-member
-                    logger.warning('Failed to load server properties: {}'
-                                   ''.format(e))
-
-        force_check = ((not getattr(self, '_schema_checked', False)) and
-                       self._api.env.force_schema_check)
-
-        if (force_check or
-                no_info or exp < time.time() or not Schema._in_cache(fp)):
-            (fp, exp) = self._get_schema()
-            self._schema_checked = True
-            _ensure_dir_created(SERVERS_DIR)
-            try:
-                with self._open_server_info(self._api.env.server, 'w') as sc:
-                    json.dump(dict(fingerprint=fp, expiration=exp), sc)
-            except Exception as e:
-                logger.warning('Failed to store server properties: {}'
-                               ''.format(e))
-
-        if not self._dict:
-            self._dict['fingerprint'] = fp
-            schema_info = self._read(self.schema_info_path)
+        self._fingerprint = fp
+        self._expiration = ttl + time.time()
+
+    def _read_schema(self):
+        with self._open_schema(self._fingerprint, 'r') as schema:
+            self._dict['fingerprint'] = self._get_schema_fingerprint(schema)
+            schema_info = json.loads(schema.read(self.schema_info_path))
             self._dict['version'] = schema_info['version']
-            for ns in self.namespaces:
-                self._dict[ns] = _SchemaNameSpace(self, ns)
 
-        self._server_schema_fingerprintr = fp
-        self._server_schema_expiration = exp
+            for name in schema.namelist():
+                ns, _slash, key = name.partition('/')
+                if ns in self.namespaces:
+                    self._dict[ns][key] = {}
 
     def __getitem__(self, key):
-        self._ensure_cached()
-        return self._dict[key]
+        try:
+            return self._namespaces[key]
+        except KeyError:
+            return self._dict[key]
 
-    def _open_archive(self, mode, fp=None):
-        if not fp:
-            fp = self['fingerprint']
-        arch_path = self.schema_path_template.format(fp)
-        return _LockedZipFile(arch_path, mode)
-
-    def _store(self, fingerprint, schema={}):
-        _ensure_dir_created(SCHEMA_DIR)
-
-        schema_info = dict(version=schema['version'],
-                           fingerprint=schema['fingerprint'])
-
-        with self._open_archive('w', fingerprint) as zf:
-            # store schema information
-            zf.writestr(self.schema_info_path, json.dumps(schema_info))
-            # store namespaces
-            for namespace in self.namespaces:
-                for member in schema[namespace]:
-                    path = self.ns_member_path_template.format(
-                        namespace,
-                        member['full_name']
-                    )
-                    zf.writestr(path, json.dumps(member))
+    def _write_schema(self):
+        try:
+            os.makedirs(self._DIR)
+        except EnvironmentError as e:
+            if e.errno != errno.EEXIST:
+                logger.warning("Failed ti write schema: {}".format(e))
+                return
+
+        with self._open_schema(self._fingerprint, 'w') as schema:
+            schema_info = {}
+            for key, value in self._dict.items():
+                if key in self.namespaces:
+                    ns = value
+                    for member in ns:
+                        path = '{}/{}'.format(key, member)
+                        schema.writestr(path, json.dumps(ns[member]))
+                else:
+                    schema_info[key] = value
+
+            schema.writestr(self.schema_info_path, json.dumps(schema_info))
 
     def _read(self, path):
-        with self._open_archive('r') as zf:
+        with self._open_schema(self._fingerprint, 'r') as zf:
             return json.loads(zf.read(path))
 
     def read_namespace_member(self, namespace, member):
-        path = self.ns_member_path_template.format(namespace, member)
-        return self._read(path)
+        value = self._dict[namespace][member]
+
+        if (not value) or ('full_name' not in value):
+            path = '{}/{}'.format(namespace, member)
+            value = self._dict[namespace].setdefault(
+                member, {}
+            ).update(self._read(path))
+
+        return value
 
     def iter_namespace(self, namespace):
-        pattern = self.ns_member_pattern_template.format(namespace)
-        with self._open_archive('r') as zf:
-            for name in zf.namelist():
-                r = re.match(pattern, name)
-                if r:
-                    yield r.groups('name')[0]
+        return iter(self._dict[namespace])
 
 
 def get_package(api, client):
     try:
         schema = api._schema
     except AttributeError:
-        schema = Schema(api, client)
-        object.__setattr__(api, '_schema', schema)
+        with ServerInfo(api.env.hostname) as server_info:
+            schema = Schema(api, server_info, client)
+            object.__setattr__(api, '_schema', schema)
 
     fingerprint = str(schema['fingerprint'])
     package_name = '{}${}'.format(__name__, fingerprint)
@@ -509,10 +563,9 @@ def get_package(api, client):
     module = types.ModuleType(module_name)
     module.__file__ = os.path.join(package_dir, 'plugins.py')
     module.register = plugable.Registry()
-    for key, plugin_cls in (('commands', _SchemaCommandPlugin),
-                            ('classes', _SchemaObjectPlugin)):
-        for full_name in schema[key]:
-            plugin = plugin_cls(full_name)
+    for plugin_cls in (_SchemaCommandPlugin, _SchemaObjectPlugin):
+        for full_name in schema[plugin_cls.schema_key]:
+            plugin = plugin_cls(str(full_name))
             plugin = module.register()(plugin)
     sys.modules[module_name] = module
 
-- 
2.7.4