2737e7
From 9e2b0c4b026f281507d1879da8398841270bc20f Mon Sep 17 00:00:00 2001
2737e7
From: Christian Heimes <cheimes@redhat.com>
2737e7
Date: Fri, 15 Jun 2018 17:03:29 +0200
2737e7
Subject: [PATCH] Sort and shuffle SRV record by priority and weight
2737e7
2737e7
On multiple occasions, SRV query answers were not properly sorted by
2737e7
priority. Records with same priority weren't randomized and shuffled.
2737e7
This caused FreeIPA to contact the same remote peer instead of
2737e7
distributing the load across all available servers.
2737e7
2737e7
Two new helper functions now take care of SRV queries. sort_prio_weight()
2737e7
sorts SRV and URI records. query_srv() combines SRV lookup with
2737e7
sort_prio_weight().
2737e7
2737e7
Fixes: https://pagure.io/freeipa/issue/7475
2737e7
Signed-off-by: Christian Heimes <cheimes@redhat.com>
2737e7
Reviewed-By: Rob Crittenden <rcritten@redhat.com>
2737e7
---
2737e7
 ipaclient/install/ipadiscovery.py       |   3 +-
2737e7
 ipalib/rpc.py                           |  21 ++---
2737e7
 ipalib/util.py                          |  11 ++-
2737e7
 ipapython/config.py                     |   8 +-
2737e7
 ipapython/dnsutil.py                    |  92 +++++++++++++++++++-
2737e7
 ipaserver/dcerpc.py                     |   4 +-
2737e7
 ipatests/test_ipapython/test_dnsutil.py | 106 ++++++++++++++++++++++++
2737e7
 7 files changed, 217 insertions(+), 28 deletions(-)
2737e7
 create mode 100644 ipatests/test_ipapython/test_dnsutil.py
2737e7
2737e7
diff --git a/ipaclient/install/ipadiscovery.py b/ipaclient/install/ipadiscovery.py
2737e7
index 46e05c971647b4f0fb4e6044ef74aff3e7919632..34142179a9f4957e842769d9d4036d2024130793 100644
2737e7
--- a/ipaclient/install/ipadiscovery.py
2737e7
+++ b/ipaclient/install/ipadiscovery.py
2737e7
@@ -25,6 +25,7 @@ from ipapython.ipa_log_manager import root_logger
2737e7
 from dns import resolver, rdatatype
2737e7
 from dns.exception import DNSException
2737e7
 from ipalib import errors
2737e7
+from ipapython.dnsutil import query_srv
2737e7
 from ipapython import ipaldap
2737e7
 from ipaplatform.paths import paths
2737e7
 from ipapython.ipautil import valid_ip, realm_to_suffix
2737e7
@@ -492,7 +493,7 @@ class IPADiscovery(object):
2737e7
         root_logger.debug("Search DNS for SRV record of %s", qname)
2737e7
 
2737e7
         try:
2737e7
-            answers = resolver.query(qname, rdatatype.SRV)
2737e7
+            answers = query_srv(qname)
2737e7
         except DNSException as e:
2737e7
             root_logger.debug("DNS record not found: %s", e.__class__.__name__)
2737e7
             answers = []
2737e7
diff --git a/ipalib/rpc.py b/ipalib/rpc.py
2737e7
index e3b8d67d69c084ad1a43390b5f93061826a27e1d..e74807d57955cd36aa8622b4441e08ee89cd313e 100644
2737e7
--- a/ipalib/rpc.py
2737e7
+++ b/ipalib/rpc.py
2737e7
@@ -43,7 +43,6 @@ import socket
2737e7
 import gzip
2737e7
 
2737e7
 import gssapi
2737e7
-from dns import resolver, rdatatype
2737e7
 from dns.exception import DNSException
2737e7
 from ssl import SSLError
2737e7
 import six
2737e7
@@ -59,7 +58,7 @@ from ipapython.ipa_log_manager import root_logger
2737e7
 from ipapython import ipautil
2737e7
 from ipapython import session_storage
2737e7
 from ipapython.cookie import Cookie
2737e7
-from ipapython.dnsutil import DNSName
2737e7
+from ipapython.dnsutil import DNSName, query_srv
2737e7
 from ipalib.text import _
2737e7
 from ipalib.util import create_https_connection
2737e7
 from ipalib.krb_utils import KRB5KDC_ERR_S_PRINCIPAL_UNKNOWN, KRB5KRB_AP_ERR_TKT_EXPIRED, \
2737e7
@@ -853,7 +852,7 @@ class RPCClient(Connectible):
2737e7
         name = '_ldap._tcp.%s.' % self.env.domain
2737e7
 
2737e7
         try:
2737e7
-            answers = resolver.query(name, rdatatype.SRV)
2737e7
+            answers = query_srv(name)
2737e7
         except DNSException:
2737e7
             answers = []
2737e7
 
2737e7
@@ -861,17 +860,11 @@ class RPCClient(Connectible):
2737e7
             server = str(answer.target).rstrip(".")
2737e7
             servers.append('https://%s%s' % (ipautil.format_netloc(server), path))
2737e7
 
2737e7
-        servers = list(set(servers))
2737e7
-        # the list/set conversion won't preserve order so stick in the
2737e7
-        # local config file version here.
2737e7
-        cfg_server = rpc_uri
2737e7
-        if cfg_server in servers:
2737e7
-            # make sure the configured master server is there just once and
2737e7
-            # it is the first one
2737e7
-            servers.remove(cfg_server)
2737e7
-            servers.insert(0, cfg_server)
2737e7
-        else:
2737e7
-            servers.insert(0, cfg_server)
2737e7
+        # make sure the configured master server is there just once and
2737e7
+        # it is the first one.
2737e7
+        if rpc_uri in servers:
2737e7
+            servers.remove(rpc_uri)
2737e7
+        servers.insert(0, rpc_uri)
2737e7
 
2737e7
         return servers
2737e7
 
2737e7
diff --git a/ipalib/util.py b/ipalib/util.py
2737e7
index 6ee65498b4de674fe4b2ee361541d3bfe648bba0..56db48638e8319859850fba449ed7c23b6e909ab 100644
2737e7
--- a/ipalib/util.py
2737e7
+++ b/ipalib/util.py
2737e7
@@ -934,14 +934,13 @@ def detect_dns_zone_realm_type(api, domain):
2737e7
 
2737e7
     try:
2737e7
         # The presence of this record is enough, return foreign in such case
2737e7
-        result = resolver.query(ad_specific_record_name, rdatatype.SRV)
2737e7
-        return 'foreign'
2737e7
-
2737e7
+        resolver.query(ad_specific_record_name, rdatatype.SRV)
2737e7
     except DNSException:
2737e7
-        pass
2737e7
+        # If we could not detect type with certainty, return unknown
2737e7
+        return 'unknown'
2737e7
+    else:
2737e7
+        return 'foreign'
2737e7
 
2737e7
-    # If we could not detect type with certainity, return unknown
2737e7
-    return 'unknown'
2737e7
 
2737e7
 def has_managed_topology(api):
2737e7
     domainlevel = api.Command['domainlevel_get']().get('result', DOMAIN_LEVEL_0)
2737e7
diff --git a/ipapython/config.py b/ipapython/config.py
2737e7
index 19abfc51ee354d2971be836fa6bad70eea3a6720..44c823b6b946c28a510e5f156061eba0b05aa059 100644
2737e7
--- a/ipapython/config.py
2737e7
+++ b/ipapython/config.py
2737e7
@@ -24,7 +24,6 @@ from optparse import (
2737e7
 from copy import copy
2737e7
 import socket
2737e7
 
2737e7
-from dns import resolver, rdatatype
2737e7
 from dns.exception import DNSException
2737e7
 import dns.name
2737e7
 # pylint: disable=import-error
2737e7
@@ -33,6 +32,7 @@ from six.moves.urllib.parse import urlsplit
2737e7
 # pylint: enable=import-error
2737e7
 
2737e7
 from ipapython.dn import DN
2737e7
+from ipapython.dnsutil import query_srv
2737e7
 
2737e7
 try:
2737e7
     # pylint: disable=ipa-forbidden-import
2737e7
@@ -195,7 +195,7 @@ def __discover_config(discover_server = True):
2737e7
             name = "_ldap._tcp." + domain
2737e7
 
2737e7
             try:
2737e7
-                servers = resolver.query(name, rdatatype.SRV)
2737e7
+                servers = query_srv(name)
2737e7
             except DNSException:
2737e7
                 # try cycling on domain components of FQDN
2737e7
                 try:
2737e7
@@ -210,7 +210,7 @@ def __discover_config(discover_server = True):
2737e7
                         return False
2737e7
                     name = "_ldap._tcp.%s" % domain
2737e7
                     try:
2737e7
-                        servers = resolver.query(name, rdatatype.SRV)
2737e7
+                        servers = query_srv(name)
2737e7
                         break
2737e7
                     except DNSException:
2737e7
                         pass
2737e7
@@ -221,7 +221,7 @@ def __discover_config(discover_server = True):
2737e7
             if not servers:
2737e7
                 name = "_ldap._tcp.%s." % config.default_domain
2737e7
                 try:
2737e7
-                    servers = resolver.query(name, rdatatype.SRV)
2737e7
+                    servers = query_srv(name)
2737e7
                 except DNSException:
2737e7
                     pass
2737e7
 
2737e7
diff --git a/ipapython/dnsutil.py b/ipapython/dnsutil.py
2737e7
index 011b722dac3e181ac52f7d92d9f44d31c5e2e6bb..25435ba51e6e7c2c6581b60eb077dd133dd29724 100644
2737e7
--- a/ipapython/dnsutil.py
2737e7
+++ b/ipapython/dnsutil.py
2737e7
@@ -17,10 +17,15 @@
2737e7
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
2737e7
 #
2737e7
 
2737e7
+import copy
2737e7
+import operator
2737e7
+import random
2737e7
+
2737e7
 import dns.name
2737e7
 import dns.exception
2737e7
 import dns.resolver
2737e7
-import copy
2737e7
+import dns.rdataclass
2737e7
+import dns.rdatatype
2737e7
 
2737e7
 import six
2737e7
 
2737e7
@@ -369,3 +374,88 @@ def check_zone_overlap(zone, raise_on_error=True):
2737e7
         if ns:
2737e7
             msg += u" and is handled by server(s): {0}".format(', '.join(ns))
2737e7
         raise ValueError(msg)
2737e7
+
2737e7
+
2737e7
+def _mix_weight(records):
2737e7
+    """Weighted population sorting for records with same priority
2737e7
+    """
2737e7
+    # trivial case
2737e7
+    if len(records) <= 1:
2737e7
+        return records
2737e7
+
2737e7
+    # Optimization for common case: If all weights are the same (e.g. 0),
2737e7
+    # just shuffle the records, which is about four times faster.
2737e7
+    if all(rr.weight == records[0].weight for rr in records):
2737e7
+        random.shuffle(records)
2737e7
+        return records
2737e7
+
2737e7
+    noweight = 0.01  # give records with 0 weight a small chance
2737e7
+    result = []
2737e7
+    records = set(records)
2737e7
+    while len(records) > 1:
2737e7
+        # Compute the sum of the weights of those RRs. Then choose a
2737e7
+        # uniform random number between 0 and the sum computed (inclusive).
2737e7
+        urn = random.uniform(0, sum(rr.weight or noweight for rr in records))
2737e7
+        # Select the RR whose running sum value is the first in the selected
2737e7
+        # order which is greater than or equal to the random number selected.
2737e7
+        acc = 0.
2737e7
+        for rr in records.copy():
2737e7
+            acc += rr.weight or noweight
2737e7
+            if acc >= urn:
2737e7
+                records.remove(rr)
2737e7
+                result.append(rr)
2737e7
+    if records:
2737e7
+        result.append(records.pop())
2737e7
+    return result
2737e7
+
2737e7
+
2737e7
+def sort_prio_weight(records):
2737e7
+    """RFC 2782 sorting algorithm for SRV and URI records
2737e7
+
2737e7
+    RFC 2782 defines a sorting algorithms for SRV records, that is also used
2737e7
+    for sorting URI records. Records are sorted by priority and than randomly
2737e7
+    shuffled according to weight.
2737e7
+
2737e7
+    This implementation also removes duplicate entries.
2737e7
+    """
2737e7
+    # order records by priority
2737e7
+    records = sorted(records, key=operator.attrgetter("priority"))
2737e7
+
2737e7
+    # remove duplicate entries
2737e7
+    uniquerecords = []
2737e7
+    seen = set()
2737e7
+    for rr in records:
2737e7
+        # A SRV record has target and port, URI just has target.
2737e7
+        target = (rr.target, getattr(rr, "port", None))
2737e7
+        if target not in seen:
2737e7
+            uniquerecords.append(rr)
2737e7
+            seen.add(target)
2737e7
+
2737e7
+    # weighted randomization of entries with same priority
2737e7
+    result = []
2737e7
+    sameprio = []
2737e7
+    for rr in uniquerecords:
2737e7
+        # add all items with same priority in a bucket
2737e7
+        if not sameprio or sameprio[0].priority == rr.priority:
2737e7
+            sameprio.append(rr)
2737e7
+        else:
2737e7
+            # got different priority, shuffle bucket
2737e7
+            result.extend(_mix_weight(sameprio))
2737e7
+            # start a new priority list
2737e7
+            sameprio = [rr]
2737e7
+    # add last batch of records with same priority
2737e7
+    if sameprio:
2737e7
+        result.extend(_mix_weight(sameprio))
2737e7
+    return result
2737e7
+
2737e7
+
2737e7
+def query_srv(qname, resolver=None, **kwargs):
2737e7
+    """Query SRV records and sort reply according to RFC 2782
2737e7
+
2737e7
+    :param qname: query name, _service._proto.domain.
2737e7
+    :return: list of dns.rdtypes.IN.SRV.SRV instances
2737e7
+    """
2737e7
+    if resolver is None:
2737e7
+        resolver = dns.resolver
2737e7
+    answer = resolver.query(qname, rdtype=dns.rdatatype.SRV, **kwargs)
2737e7
+    return sort_prio_weight(answer)
2737e7
diff --git a/ipaserver/dcerpc.py b/ipaserver/dcerpc.py
2737e7
index ac1b2a34784df491a3851aa21bbadbec2297241c..4e957b19292f51a7f6e3540dc38590737c7ae5e4 100644
2737e7
--- a/ipaserver/dcerpc.py
2737e7
+++ b/ipaserver/dcerpc.py
2737e7
@@ -30,6 +30,7 @@ from ipalib import errors
2737e7
 from ipapython import ipautil
2737e7
 from ipapython.ipa_log_manager import root_logger
2737e7
 from ipapython.dn import DN
2737e7
+from ipapython.dnsutil import query_srv
2737e7
 from ipaserver.install import installutils
2737e7
 from ipaserver.dcerpc_common import (TRUST_BIDIRECTIONAL,
2737e7
                                      TRUST_JOIN_EXTERNAL,
2737e7
@@ -55,7 +56,6 @@ import samba
2737e7
 import ldap as _ldap
2737e7
 from ipapython import ipaldap
2737e7
 from ipapython.dnsutil import DNSName
2737e7
-from dns import resolver, rdatatype
2737e7
 from dns.exception import DNSException
2737e7
 import pysss_nss_idmap
2737e7
 import pysss
2737e7
@@ -795,7 +795,7 @@ class DomainValidator(object):
2737e7
             gc_name = '_gc._tcp.%s.' % info['dns_domain']
2737e7
 
2737e7
             try:
2737e7
-                answers = resolver.query(gc_name, rdatatype.SRV)
2737e7
+                answers = query_srv(gc_name)
2737e7
             except DNSException as e:
2737e7
                 answers = []
2737e7
 
2737e7
diff --git a/ipatests/test_ipapython/test_dnsutil.py b/ipatests/test_ipapython/test_dnsutil.py
2737e7
new file mode 100644
2737e7
index 0000000000000000000000000000000000000000..36adb077cf38f6d036aa1048b201dee7d08eb310
2737e7
--- /dev/null
2737e7
+++ b/ipatests/test_ipapython/test_dnsutil.py
2737e7
@@ -0,0 +1,106 @@
2737e7
+#
2737e7
+# Copyright (C) 2018  FreeIPA Contributors.  See COPYING for license
2737e7
+#
2737e7
+import dns.name
2737e7
+import dns.rdataclass
2737e7
+import dns.rdatatype
2737e7
+from dns.rdtypes.IN.SRV import SRV
2737e7
+from dns.rdtypes.ANY.URI import URI
2737e7
+
2737e7
+from ipapython import dnsutil
2737e7
+
2737e7
+import pytest
2737e7
+
2737e7
+
2737e7
+def mksrv(priority, weight, port, target):
2737e7
+    return SRV(
2737e7
+        rdclass=dns.rdataclass.IN,
2737e7
+        rdtype=dns.rdatatype.SRV,
2737e7
+        priority=priority,
2737e7
+        weight=weight,
2737e7
+        port=port,
2737e7
+        target=dns.name.from_text(target)
2737e7
+    )
2737e7
+
2737e7
+
2737e7
+def mkuri(priority, weight, target):
2737e7
+    return URI(
2737e7
+        rdclass=dns.rdataclass.IN,
2737e7
+        rdtype=dns.rdatatype.URI,
2737e7
+        priority=priority,
2737e7
+        weight=weight,
2737e7
+        target=target
2737e7
+    )
2737e7
+
2737e7
+
2737e7
+class TestSortSRV(object):
2737e7
+    def test_empty(self):
2737e7
+        assert dnsutil.sort_prio_weight([]) == []
2737e7
+
2737e7
+    def test_one(self):
2737e7
+        h1 = mksrv(1, 0, 443, u"host1")
2737e7
+        assert dnsutil.sort_prio_weight([h1]) == [h1]
2737e7
+
2737e7
+        h2 = mksrv(10, 5, 443, u"host2")
2737e7
+        assert dnsutil.sort_prio_weight([h2]) == [h2]
2737e7
+
2737e7
+    def test_prio(self):
2737e7
+        h1 = mksrv(1, 0, 443, u"host1")
2737e7
+        h2 = mksrv(2, 0, 443, u"host2")
2737e7
+        h3 = mksrv(3, 0, 443, u"host3")
2737e7
+        assert dnsutil.sort_prio_weight([h3, h2, h1]) == [h1, h2, h3]
2737e7
+        assert dnsutil.sort_prio_weight([h3, h3, h3]) == [h3]
2737e7
+        assert dnsutil.sort_prio_weight([h2, h2, h1, h1]) == [h1, h2]
2737e7
+
2737e7
+        h380 = mksrv(4, 0, 80, u"host3")
2737e7
+        assert dnsutil.sort_prio_weight([h1, h3, h380]) == [h1, h3, h380]
2737e7
+
2737e7
+        hs = mksrv(-1, 0, 443, u"special")
2737e7
+        assert dnsutil.sort_prio_weight([h1, h2, hs]) == [hs, h1, h2]
2737e7
+
2737e7
+    def assert_permutations(self, answers, permutations):
2737e7
+        seen = set()
2737e7
+        for _unused in range(1000):
2737e7
+            result = tuple(dnsutil.sort_prio_weight(answers))
2737e7
+            assert result in permutations
2737e7
+            seen.add(result)
2737e7
+            if seen == permutations:
2737e7
+                break
2737e7
+        else:
2737e7
+            pytest.fail("sorting didn't exhaust all permutations.")
2737e7
+
2737e7
+    def test_sameprio(self):
2737e7
+        h1 = mksrv(1, 0, 443, u"host1")
2737e7
+        h2 = mksrv(1, 0, 443, u"host2")
2737e7
+        permutations = {
2737e7
+            (h1, h2),
2737e7
+            (h2, h1),
2737e7
+        }
2737e7
+        self.assert_permutations([h1, h2], permutations)
2737e7
+
2737e7
+    def test_weight(self):
2737e7
+        h1 = mksrv(1, 0, 443, u"host1")
2737e7
+        h2_w15 = mksrv(2, 15, 443, u"host2")
2737e7
+        h3_w10 = mksrv(2, 10, 443, u"host3")
2737e7
+
2737e7
+        permutations = {
2737e7
+            (h1, h2_w15, h3_w10),
2737e7
+            (h1, h3_w10, h2_w15),
2737e7
+        }
2737e7
+        self.assert_permutations([h1, h2_w15, h3_w10], permutations)
2737e7
+
2737e7
+    def test_large(self):
2737e7
+        records = tuple(
2737e7
+            mksrv(1, i, 443, "host{}".format(i)) for i in range(1000)
2737e7
+        )
2737e7
+        assert len(dnsutil.sort_prio_weight(records)) == len(records)
2737e7
+
2737e7
+
2737e7
+class TestSortURI(object):
2737e7
+    def test_prio(self):
2737e7
+        h1 = mkuri(1, 0, u"https://host1/api")
2737e7
+        h2 = mkuri(2, 0, u"https://host2/api")
2737e7
+        h3 = mkuri(3, 0, u"https://host3/api")
2737e7
+        assert dnsutil.sort_prio_weight([h3, h2, h1]) == [h1, h2, h3]
2737e7
+        assert dnsutil.sort_prio_weight([h3, h3, h3]) == [h3]
2737e7
+        assert dnsutil.sort_prio_weight([h2, h2, h1, h1]) == [h1, h2]
2737e7
-- 
2737e7
2.17.1
2737e7