pgreco / rpms / ipa

Forked from forks/areguera/rpms/ipa 4 years ago
Clone

Blame SOURCES/0002-Sort-and-shuffle-SRV-record-by-priority-and-weight.patch

f65af0
From a56b9ef37dde90a593da6adbae2525048c5c8627 Mon Sep 17 00:00:00 2001
979ee0
From: Christian Heimes <cheimes@redhat.com>
979ee0
Date: Fri, 15 Jun 2018 17:03:29 +0200
979ee0
Subject: [PATCH] Sort and shuffle SRV record by priority and weight
979ee0
979ee0
On multiple occasions, SRV query answers were not properly sorted by
979ee0
priority. Records with same priority weren't randomized and shuffled.
979ee0
This caused FreeIPA to contact the same remote peer instead of
979ee0
distributing the load across all available servers.
979ee0
979ee0
Two new helper functions now take care of SRV queries. sort_prio_weight()
979ee0
sorts SRV and URI records. query_srv() combines SRV lookup with
979ee0
sort_prio_weight().
979ee0
979ee0
Fixes: https://pagure.io/freeipa/issue/7475
979ee0
Signed-off-by: Christian Heimes <cheimes@redhat.com>
979ee0
Reviewed-By: Rob Crittenden <rcritten@redhat.com>
979ee0
---
f65af0
 ipaclient/install/ipadiscovery.py       |   5 +-
f65af0
 ipalib/rpc.py                           |  21 +++----
f65af0
 ipalib/util.py                          |  11 ++--
f65af0
 ipapython/config.py                     |   8 +--
f65af0
 ipapython/dnsutil.py                    |  92 ++++++++++++++++++++++++++-
979ee0
 ipaserver/dcerpc.py                     |   4 +-
f65af0
 ipatests/test_ipapython/test_dnsutil.py | 106 ++++++++++++++++++++++++++++++++
f65af0
 7 files changed, 217 insertions(+), 30 deletions(-)
979ee0
 create mode 100644 ipatests/test_ipapython/test_dnsutil.py
979ee0
979ee0
diff --git a/ipaclient/install/ipadiscovery.py b/ipaclient/install/ipadiscovery.py
f65af0
index 363970c86ba65b7ec49e403adc745fe61b2242aa..a8283d8e6ad070a8b6626a187ff55485bb74eba4 100644
979ee0
--- a/ipaclient/install/ipadiscovery.py
979ee0
+++ b/ipaclient/install/ipadiscovery.py
f65af0
@@ -20,7 +20,6 @@
f65af0
 from __future__ import absolute_import
f65af0
 
f65af0
 import logging
f65af0
-import operator
f65af0
 import socket
f65af0
 
f65af0
 import six
f65af0
@@ -28,6 +27,7 @@ import six
979ee0
 from dns import resolver, rdatatype
979ee0
 from dns.exception import DNSException
979ee0
 from ipalib import errors
979ee0
+from ipapython.dnsutil import query_srv
979ee0
 from ipapython import ipaldap
979ee0
 from ipaplatform.paths import paths
979ee0
 from ipapython.ipautil import valid_ip, realm_to_suffix
f65af0
@@ -498,8 +498,7 @@ class IPADiscovery(object):
f65af0
         logger.debug("Search DNS for SRV record of %s", qname)
979ee0
 
979ee0
         try:
979ee0
-            answers = resolver.query(qname, rdatatype.SRV)
f65af0
-            answers = sorted(answers, key=operator.attrgetter('priority'))
979ee0
+            answers = query_srv(qname)
979ee0
         except DNSException as e:
f65af0
             logger.debug("DNS record not found: %s", e.__class__.__name__)
979ee0
             answers = []
979ee0
diff --git a/ipalib/rpc.py b/ipalib/rpc.py
f65af0
index c6a8989f5dc157f4c4e59637e6e7d114c5fa952c..17368f160148d6b821ec4934d31b4b4034a7e67b 100644
979ee0
--- a/ipalib/rpc.py
979ee0
+++ b/ipalib/rpc.py
f65af0
@@ -45,7 +45,6 @@ import gzip
f65af0
 from cryptography import x509 as crypto_x509
979ee0
 
979ee0
 import gssapi
979ee0
-from dns import resolver, rdatatype
979ee0
 from dns.exception import DNSException
979ee0
 from ssl import SSLError
979ee0
 import six
f65af0
@@ -61,7 +60,7 @@ from ipalib.x509 import Encoding as x509_Encoding
979ee0
 from ipapython import ipautil
979ee0
 from ipapython import session_storage
979ee0
 from ipapython.cookie import Cookie
979ee0
-from ipapython.dnsutil import DNSName
979ee0
+from ipapython.dnsutil import DNSName, query_srv
979ee0
 from ipalib.text import _
979ee0
 from ipalib.util import create_https_connection
979ee0
 from ipalib.krb_utils import KRB5KDC_ERR_S_PRINCIPAL_UNKNOWN, KRB5KRB_AP_ERR_TKT_EXPIRED, \
f65af0
@@ -878,7 +877,7 @@ class RPCClient(Connectible):
979ee0
         name = '_ldap._tcp.%s.' % self.env.domain
979ee0
 
979ee0
         try:
979ee0
-            answers = resolver.query(name, rdatatype.SRV)
979ee0
+            answers = query_srv(name)
979ee0
         except DNSException:
979ee0
             answers = []
979ee0
 
f65af0
@@ -886,17 +885,11 @@ class RPCClient(Connectible):
979ee0
             server = str(answer.target).rstrip(".")
979ee0
             servers.append('https://%s%s' % (ipautil.format_netloc(server), path))
979ee0
 
979ee0
-        servers = list(set(servers))
979ee0
-        # the list/set conversion won't preserve order so stick in the
979ee0
-        # local config file version here.
979ee0
-        cfg_server = rpc_uri
979ee0
-        if cfg_server in servers:
979ee0
-            # make sure the configured master server is there just once and
979ee0
-            # it is the first one
979ee0
-            servers.remove(cfg_server)
979ee0
-            servers.insert(0, cfg_server)
979ee0
-        else:
979ee0
-            servers.insert(0, cfg_server)
979ee0
+        # make sure the configured master server is there just once and
979ee0
+        # it is the first one.
979ee0
+        if rpc_uri in servers:
979ee0
+            servers.remove(rpc_uri)
979ee0
+        servers.insert(0, rpc_uri)
979ee0
 
979ee0
         return servers
979ee0
 
979ee0
diff --git a/ipalib/util.py b/ipalib/util.py
f65af0
index ebf6eb3faf91cefc02514afa84ad9b63d0f82b0b..592821f9ff4f9c16fb2589d697e3e202443d0eba 100644
979ee0
--- a/ipalib/util.py
979ee0
+++ b/ipalib/util.py
f65af0
@@ -973,14 +973,13 @@ def detect_dns_zone_realm_type(api, domain):
979ee0
 
979ee0
     try:
979ee0
         # The presence of this record is enough, return foreign in such case
979ee0
-        result = resolver.query(ad_specific_record_name, rdatatype.SRV)
979ee0
-        return 'foreign'
979ee0
-
979ee0
+        resolver.query(ad_specific_record_name, rdatatype.SRV)
979ee0
     except DNSException:
979ee0
-        pass
979ee0
+        # If we could not detect type with certainty, return unknown
979ee0
+        return 'unknown'
979ee0
+    else:
979ee0
+        return 'foreign'
979ee0
 
979ee0
-    # If we could not detect type with certainity, return unknown
979ee0
-    return 'unknown'
979ee0
 
979ee0
 def has_managed_topology(api):
979ee0
     domainlevel = api.Command['domainlevel_get']().get('result', DOMAIN_LEVEL_0)
979ee0
diff --git a/ipapython/config.py b/ipapython/config.py
f65af0
index c3360779f1320e7d41f4d98b317d930a1d0b8242..f701122bd4828693a41848668e49052912042837 100644
979ee0
--- a/ipapython/config.py
979ee0
+++ b/ipapython/config.py
f65af0
@@ -26,7 +26,6 @@ from copy import copy
979ee0
 import socket
f65af0
 import functools
979ee0
 
979ee0
-from dns import resolver, rdatatype
979ee0
 from dns.exception import DNSException
979ee0
 import dns.name
979ee0
 # pylint: disable=import-error
f65af0
@@ -36,6 +35,7 @@ from six.moves.urllib.parse import urlsplit
979ee0
 
f65af0
 from ipaplatform.paths import paths
979ee0
 from ipapython.dn import DN
979ee0
+from ipapython.dnsutil import query_srv
f65af0
 from ipapython.ipautil import CheckedIPAddress, CheckedIPAddressLoopback
f65af0
 
979ee0
 
f65af0
@@ -210,7 +210,7 @@ def __discover_config(discover_server = True):
979ee0
             name = "_ldap._tcp." + domain
979ee0
 
979ee0
             try:
979ee0
-                servers = resolver.query(name, rdatatype.SRV)
979ee0
+                servers = query_srv(name)
979ee0
             except DNSException:
979ee0
                 # try cycling on domain components of FQDN
979ee0
                 try:
f65af0
@@ -225,7 +225,7 @@ def __discover_config(discover_server = True):
979ee0
                         return False
979ee0
                     name = "_ldap._tcp.%s" % domain
979ee0
                     try:
979ee0
-                        servers = resolver.query(name, rdatatype.SRV)
979ee0
+                        servers = query_srv(name)
979ee0
                         break
979ee0
                     except DNSException:
979ee0
                         pass
f65af0
@@ -236,7 +236,7 @@ def __discover_config(discover_server = True):
979ee0
             if not servers:
979ee0
                 name = "_ldap._tcp.%s." % config.default_domain
979ee0
                 try:
979ee0
-                    servers = resolver.query(name, rdatatype.SRV)
979ee0
+                    servers = query_srv(name)
979ee0
                 except DNSException:
979ee0
                     pass
979ee0
 
979ee0
diff --git a/ipapython/dnsutil.py b/ipapython/dnsutil.py
f65af0
index b40302d0efbb32108626d2cc1216ad4858343407..6157183a0fa5802a5fb772f078ee4c1688857fc8 100644
979ee0
--- a/ipapython/dnsutil.py
979ee0
+++ b/ipapython/dnsutil.py
f65af0
@@ -17,12 +17,17 @@
979ee0
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
979ee0
 #
979ee0
 
979ee0
+import copy
f65af0
 import logging
979ee0
+import operator
979ee0
+import random
f65af0
 
979ee0
 import dns.name
979ee0
 import dns.exception
979ee0
 import dns.resolver
979ee0
-import copy
979ee0
+import dns.rdataclass
979ee0
+import dns.rdatatype
f65af0
+
979ee0
 
979ee0
 import six
979ee0
 
f65af0
@@ -373,3 +378,88 @@ def check_zone_overlap(zone, raise_on_error=True):
979ee0
         if ns:
979ee0
             msg += u" and is handled by server(s): {0}".format(', '.join(ns))
979ee0
         raise ValueError(msg)
979ee0
+
979ee0
+
979ee0
+def _mix_weight(records):
979ee0
+    """Weighted population sorting for records with same priority
979ee0
+    """
979ee0
+    # trivial case
979ee0
+    if len(records) <= 1:
979ee0
+        return records
979ee0
+
979ee0
+    # Optimization for common case: If all weights are the same (e.g. 0),
979ee0
+    # just shuffle the records, which is about four times faster.
979ee0
+    if all(rr.weight == records[0].weight for rr in records):
979ee0
+        random.shuffle(records)
979ee0
+        return records
979ee0
+
979ee0
+    noweight = 0.01  # give records with 0 weight a small chance
979ee0
+    result = []
979ee0
+    records = set(records)
979ee0
+    while len(records) > 1:
979ee0
+        # Compute the sum of the weights of those RRs. Then choose a
979ee0
+        # uniform random number between 0 and the sum computed (inclusive).
979ee0
+        urn = random.uniform(0, sum(rr.weight or noweight for rr in records))
979ee0
+        # Select the RR whose running sum value is the first in the selected
979ee0
+        # order which is greater than or equal to the random number selected.
979ee0
+        acc = 0.
979ee0
+        for rr in records.copy():
979ee0
+            acc += rr.weight or noweight
979ee0
+            if acc >= urn:
979ee0
+                records.remove(rr)
979ee0
+                result.append(rr)
979ee0
+    if records:
979ee0
+        result.append(records.pop())
979ee0
+    return result
979ee0
+
979ee0
+
979ee0
+def sort_prio_weight(records):
979ee0
+    """RFC 2782 sorting algorithm for SRV and URI records
979ee0
+
979ee0
+    RFC 2782 defines a sorting algorithms for SRV records, that is also used
979ee0
+    for sorting URI records. Records are sorted by priority and than randomly
979ee0
+    shuffled according to weight.
979ee0
+
979ee0
+    This implementation also removes duplicate entries.
979ee0
+    """
979ee0
+    # order records by priority
979ee0
+    records = sorted(records, key=operator.attrgetter("priority"))
979ee0
+
979ee0
+    # remove duplicate entries
979ee0
+    uniquerecords = []
979ee0
+    seen = set()
979ee0
+    for rr in records:
979ee0
+        # A SRV record has target and port, URI just has target.
979ee0
+        target = (rr.target, getattr(rr, "port", None))
979ee0
+        if target not in seen:
979ee0
+            uniquerecords.append(rr)
979ee0
+            seen.add(target)
979ee0
+
979ee0
+    # weighted randomization of entries with same priority
979ee0
+    result = []
979ee0
+    sameprio = []
979ee0
+    for rr in uniquerecords:
979ee0
+        # add all items with same priority in a bucket
979ee0
+        if not sameprio or sameprio[0].priority == rr.priority:
979ee0
+            sameprio.append(rr)
979ee0
+        else:
979ee0
+            # got different priority, shuffle bucket
979ee0
+            result.extend(_mix_weight(sameprio))
979ee0
+            # start a new priority list
979ee0
+            sameprio = [rr]
979ee0
+    # add last batch of records with same priority
979ee0
+    if sameprio:
979ee0
+        result.extend(_mix_weight(sameprio))
979ee0
+    return result
979ee0
+
979ee0
+
979ee0
+def query_srv(qname, resolver=None, **kwargs):
979ee0
+    """Query SRV records and sort reply according to RFC 2782
979ee0
+
979ee0
+    :param qname: query name, _service._proto.domain.
979ee0
+    :return: list of dns.rdtypes.IN.SRV.SRV instances
979ee0
+    """
979ee0
+    if resolver is None:
979ee0
+        resolver = dns.resolver
979ee0
+    answer = resolver.query(qname, rdtype=dns.rdatatype.SRV, **kwargs)
979ee0
+    return sort_prio_weight(answer)
979ee0
diff --git a/ipaserver/dcerpc.py b/ipaserver/dcerpc.py
f65af0
index e3aa9f6a6c239f8606f09bde06f21cb9cf64eb14..1e2d6fbbb5372bea7c6d68b73acededfe714a56a 100644
979ee0
--- a/ipaserver/dcerpc.py
979ee0
+++ b/ipaserver/dcerpc.py
f65af0
@@ -32,6 +32,7 @@ from ipalib import api, _
f65af0
 from ipalib import errors
979ee0
 from ipapython import ipautil
979ee0
 from ipapython.dn import DN
979ee0
+from ipapython.dnsutil import query_srv
f65af0
 from ipapython.ipaldap import ldap_initialize
979ee0
 from ipaserver.install import installutils
979ee0
 from ipaserver.dcerpc_common import (TRUST_BIDIRECTIONAL,
979ee0
@@ -55,7 +56,6 @@ import samba
979ee0
 import ldap as _ldap
979ee0
 from ipapython import ipaldap
979ee0
 from ipapython.dnsutil import DNSName
979ee0
-from dns import resolver, rdatatype
979ee0
 from dns.exception import DNSException
979ee0
 import pysss_nss_idmap
979ee0
 import pysss
f65af0
@@ -802,7 +802,7 @@ class DomainValidator(object):
979ee0
             gc_name = '_gc._tcp.%s.' % info['dns_domain']
979ee0
 
979ee0
             try:
979ee0
-                answers = resolver.query(gc_name, rdatatype.SRV)
979ee0
+                answers = query_srv(gc_name)
979ee0
             except DNSException as e:
979ee0
                 answers = []
979ee0
 
979ee0
diff --git a/ipatests/test_ipapython/test_dnsutil.py b/ipatests/test_ipapython/test_dnsutil.py
979ee0
new file mode 100644
979ee0
index 0000000000000000000000000000000000000000..36adb077cf38f6d036aa1048b201dee7d08eb310
979ee0
--- /dev/null
979ee0
+++ b/ipatests/test_ipapython/test_dnsutil.py
979ee0
@@ -0,0 +1,106 @@
979ee0
+#
979ee0
+# Copyright (C) 2018  FreeIPA Contributors.  See COPYING for license
979ee0
+#
979ee0
+import dns.name
979ee0
+import dns.rdataclass
979ee0
+import dns.rdatatype
979ee0
+from dns.rdtypes.IN.SRV import SRV
979ee0
+from dns.rdtypes.ANY.URI import URI
979ee0
+
979ee0
+from ipapython import dnsutil
979ee0
+
979ee0
+import pytest
979ee0
+
979ee0
+
979ee0
+def mksrv(priority, weight, port, target):
979ee0
+    return SRV(
979ee0
+        rdclass=dns.rdataclass.IN,
979ee0
+        rdtype=dns.rdatatype.SRV,
979ee0
+        priority=priority,
979ee0
+        weight=weight,
979ee0
+        port=port,
979ee0
+        target=dns.name.from_text(target)
979ee0
+    )
979ee0
+
979ee0
+
979ee0
+def mkuri(priority, weight, target):
979ee0
+    return URI(
979ee0
+        rdclass=dns.rdataclass.IN,
979ee0
+        rdtype=dns.rdatatype.URI,
979ee0
+        priority=priority,
979ee0
+        weight=weight,
979ee0
+        target=target
979ee0
+    )
979ee0
+
979ee0
+
979ee0
+class TestSortSRV(object):
979ee0
+    def test_empty(self):
979ee0
+        assert dnsutil.sort_prio_weight([]) == []
979ee0
+
979ee0
+    def test_one(self):
979ee0
+        h1 = mksrv(1, 0, 443, u"host1")
979ee0
+        assert dnsutil.sort_prio_weight([h1]) == [h1]
979ee0
+
979ee0
+        h2 = mksrv(10, 5, 443, u"host2")
979ee0
+        assert dnsutil.sort_prio_weight([h2]) == [h2]
979ee0
+
979ee0
+    def test_prio(self):
979ee0
+        h1 = mksrv(1, 0, 443, u"host1")
979ee0
+        h2 = mksrv(2, 0, 443, u"host2")
979ee0
+        h3 = mksrv(3, 0, 443, u"host3")
979ee0
+        assert dnsutil.sort_prio_weight([h3, h2, h1]) == [h1, h2, h3]
979ee0
+        assert dnsutil.sort_prio_weight([h3, h3, h3]) == [h3]
979ee0
+        assert dnsutil.sort_prio_weight([h2, h2, h1, h1]) == [h1, h2]
979ee0
+
979ee0
+        h380 = mksrv(4, 0, 80, u"host3")
979ee0
+        assert dnsutil.sort_prio_weight([h1, h3, h380]) == [h1, h3, h380]
979ee0
+
979ee0
+        hs = mksrv(-1, 0, 443, u"special")
979ee0
+        assert dnsutil.sort_prio_weight([h1, h2, hs]) == [hs, h1, h2]
979ee0
+
979ee0
+    def assert_permutations(self, answers, permutations):
979ee0
+        seen = set()
979ee0
+        for _unused in range(1000):
979ee0
+            result = tuple(dnsutil.sort_prio_weight(answers))
979ee0
+            assert result in permutations
979ee0
+            seen.add(result)
979ee0
+            if seen == permutations:
979ee0
+                break
979ee0
+        else:
979ee0
+            pytest.fail("sorting didn't exhaust all permutations.")
979ee0
+
979ee0
+    def test_sameprio(self):
979ee0
+        h1 = mksrv(1, 0, 443, u"host1")
979ee0
+        h2 = mksrv(1, 0, 443, u"host2")
979ee0
+        permutations = {
979ee0
+            (h1, h2),
979ee0
+            (h2, h1),
979ee0
+        }
979ee0
+        self.assert_permutations([h1, h2], permutations)
979ee0
+
979ee0
+    def test_weight(self):
979ee0
+        h1 = mksrv(1, 0, 443, u"host1")
979ee0
+        h2_w15 = mksrv(2, 15, 443, u"host2")
979ee0
+        h3_w10 = mksrv(2, 10, 443, u"host3")
979ee0
+
979ee0
+        permutations = {
979ee0
+            (h1, h2_w15, h3_w10),
979ee0
+            (h1, h3_w10, h2_w15),
979ee0
+        }
979ee0
+        self.assert_permutations([h1, h2_w15, h3_w10], permutations)
979ee0
+
979ee0
+    def test_large(self):
979ee0
+        records = tuple(
979ee0
+            mksrv(1, i, 443, "host{}".format(i)) for i in range(1000)
979ee0
+        )
979ee0
+        assert len(dnsutil.sort_prio_weight(records)) == len(records)
979ee0
+
979ee0
+
979ee0
+class TestSortURI(object):
979ee0
+    def test_prio(self):
979ee0
+        h1 = mkuri(1, 0, u"https://host1/api")
979ee0
+        h2 = mkuri(2, 0, u"https://host2/api")
979ee0
+        h3 = mkuri(3, 0, u"https://host3/api")
979ee0
+        assert dnsutil.sort_prio_weight([h3, h2, h1]) == [h1, h2, h3]
979ee0
+        assert dnsutil.sort_prio_weight([h3, h3, h3]) == [h3]
979ee0
+        assert dnsutil.sort_prio_weight([h2, h2, h1, h1]) == [h1, h2]
979ee0
-- 
f65af0
2.14.4
979ee0