Blob Blame History Raw
From 9e2b0c4b026f281507d1879da8398841270bc20f Mon Sep 17 00:00:00 2001
From: Christian Heimes <cheimes@redhat.com>
Date: Fri, 15 Jun 2018 17:03:29 +0200
Subject: [PATCH] Sort and shuffle SRV record by priority and weight

On multiple occasions, SRV query answers were not properly sorted by
priority. Records with same priority weren't randomized and shuffled.
This caused FreeIPA to contact the same remote peer instead of
distributing the load across all available servers.

Two new helper functions now take care of SRV queries. sort_prio_weight()
sorts SRV and URI records. query_srv() combines SRV lookup with
sort_prio_weight().

Fixes: https://pagure.io/freeipa/issue/7475
Signed-off-by: Christian Heimes <cheimes@redhat.com>
Reviewed-By: Rob Crittenden <rcritten@redhat.com>
---
 ipaclient/install/ipadiscovery.py       |   3 +-
 ipalib/rpc.py                           |  21 ++---
 ipalib/util.py                          |  11 ++-
 ipapython/config.py                     |   8 +-
 ipapython/dnsutil.py                    |  92 +++++++++++++++++++-
 ipaserver/dcerpc.py                     |   4 +-
 ipatests/test_ipapython/test_dnsutil.py | 106 ++++++++++++++++++++++++
 7 files changed, 217 insertions(+), 28 deletions(-)
 create mode 100644 ipatests/test_ipapython/test_dnsutil.py

diff --git a/ipaclient/install/ipadiscovery.py b/ipaclient/install/ipadiscovery.py
index 46e05c971647b4f0fb4e6044ef74aff3e7919632..34142179a9f4957e842769d9d4036d2024130793 100644
--- a/ipaclient/install/ipadiscovery.py
+++ b/ipaclient/install/ipadiscovery.py
@@ -25,6 +25,7 @@ from ipapython.ipa_log_manager import root_logger
 from dns import resolver, rdatatype
 from dns.exception import DNSException
 from ipalib import errors
+from ipapython.dnsutil import query_srv
 from ipapython import ipaldap
 from ipaplatform.paths import paths
 from ipapython.ipautil import valid_ip, realm_to_suffix
@@ -492,7 +493,7 @@ class IPADiscovery(object):
         root_logger.debug("Search DNS for SRV record of %s", qname)
 
         try:
-            answers = resolver.query(qname, rdatatype.SRV)
+            answers = query_srv(qname)
         except DNSException as e:
             root_logger.debug("DNS record not found: %s", e.__class__.__name__)
             answers = []
diff --git a/ipalib/rpc.py b/ipalib/rpc.py
index e3b8d67d69c084ad1a43390b5f93061826a27e1d..e74807d57955cd36aa8622b4441e08ee89cd313e 100644
--- a/ipalib/rpc.py
+++ b/ipalib/rpc.py
@@ -43,7 +43,6 @@ import socket
 import gzip
 
 import gssapi
-from dns import resolver, rdatatype
 from dns.exception import DNSException
 from ssl import SSLError
 import six
@@ -59,7 +58,7 @@ from ipapython.ipa_log_manager import root_logger
 from ipapython import ipautil
 from ipapython import session_storage
 from ipapython.cookie import Cookie
-from ipapython.dnsutil import DNSName
+from ipapython.dnsutil import DNSName, query_srv
 from ipalib.text import _
 from ipalib.util import create_https_connection
 from ipalib.krb_utils import KRB5KDC_ERR_S_PRINCIPAL_UNKNOWN, KRB5KRB_AP_ERR_TKT_EXPIRED, \
@@ -853,7 +852,7 @@ class RPCClient(Connectible):
         name = '_ldap._tcp.%s.' % self.env.domain
 
         try:
-            answers = resolver.query(name, rdatatype.SRV)
+            answers = query_srv(name)
         except DNSException:
             answers = []
 
@@ -861,17 +860,11 @@ class RPCClient(Connectible):
             server = str(answer.target).rstrip(".")
             servers.append('https://%s%s' % (ipautil.format_netloc(server), path))
 
-        servers = list(set(servers))
-        # the list/set conversion won't preserve order so stick in the
-        # local config file version here.
-        cfg_server = rpc_uri
-        if cfg_server in servers:
-            # make sure the configured master server is there just once and
-            # it is the first one
-            servers.remove(cfg_server)
-            servers.insert(0, cfg_server)
-        else:
-            servers.insert(0, cfg_server)
+        # make sure the configured master server is there just once and
+        # it is the first one.
+        if rpc_uri in servers:
+            servers.remove(rpc_uri)
+        servers.insert(0, rpc_uri)
 
         return servers
 
diff --git a/ipalib/util.py b/ipalib/util.py
index 6ee65498b4de674fe4b2ee361541d3bfe648bba0..56db48638e8319859850fba449ed7c23b6e909ab 100644
--- a/ipalib/util.py
+++ b/ipalib/util.py
@@ -934,14 +934,13 @@ def detect_dns_zone_realm_type(api, domain):
 
     try:
         # The presence of this record is enough, return foreign in such case
-        result = resolver.query(ad_specific_record_name, rdatatype.SRV)
-        return 'foreign'
-
+        resolver.query(ad_specific_record_name, rdatatype.SRV)
     except DNSException:
-        pass
+        # If we could not detect type with certainty, return unknown
+        return 'unknown'
+    else:
+        return 'foreign'
 
-    # If we could not detect type with certainity, return unknown
-    return 'unknown'
 
 def has_managed_topology(api):
     domainlevel = api.Command['domainlevel_get']().get('result', DOMAIN_LEVEL_0)
diff --git a/ipapython/config.py b/ipapython/config.py
index 19abfc51ee354d2971be836fa6bad70eea3a6720..44c823b6b946c28a510e5f156061eba0b05aa059 100644
--- a/ipapython/config.py
+++ b/ipapython/config.py
@@ -24,7 +24,6 @@ from optparse import (
 from copy import copy
 import socket
 
-from dns import resolver, rdatatype
 from dns.exception import DNSException
 import dns.name
 # pylint: disable=import-error
@@ -33,6 +32,7 @@ from six.moves.urllib.parse import urlsplit
 # pylint: enable=import-error
 
 from ipapython.dn import DN
+from ipapython.dnsutil import query_srv
 
 try:
     # pylint: disable=ipa-forbidden-import
@@ -195,7 +195,7 @@ def __discover_config(discover_server = True):
             name = "_ldap._tcp." + domain
 
             try:
-                servers = resolver.query(name, rdatatype.SRV)
+                servers = query_srv(name)
             except DNSException:
                 # try cycling on domain components of FQDN
                 try:
@@ -210,7 +210,7 @@ def __discover_config(discover_server = True):
                         return False
                     name = "_ldap._tcp.%s" % domain
                     try:
-                        servers = resolver.query(name, rdatatype.SRV)
+                        servers = query_srv(name)
                         break
                     except DNSException:
                         pass
@@ -221,7 +221,7 @@ def __discover_config(discover_server = True):
             if not servers:
                 name = "_ldap._tcp.%s." % config.default_domain
                 try:
-                    servers = resolver.query(name, rdatatype.SRV)
+                    servers = query_srv(name)
                 except DNSException:
                     pass
 
diff --git a/ipapython/dnsutil.py b/ipapython/dnsutil.py
index 011b722dac3e181ac52f7d92d9f44d31c5e2e6bb..25435ba51e6e7c2c6581b60eb077dd133dd29724 100644
--- a/ipapython/dnsutil.py
+++ b/ipapython/dnsutil.py
@@ -17,10 +17,15 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 #
 
+import copy
+import operator
+import random
+
 import dns.name
 import dns.exception
 import dns.resolver
-import copy
+import dns.rdataclass
+import dns.rdatatype
 
 import six
 
@@ -369,3 +374,88 @@ def check_zone_overlap(zone, raise_on_error=True):
         if ns:
             msg += u" and is handled by server(s): {0}".format(', '.join(ns))
         raise ValueError(msg)
+
+
+def _mix_weight(records):
+    """Weighted population sorting for records with same priority
+    """
+    # trivial case
+    if len(records) <= 1:
+        return records
+
+    # Optimization for common case: If all weights are the same (e.g. 0),
+    # just shuffle the records, which is about four times faster.
+    if all(rr.weight == records[0].weight for rr in records):
+        random.shuffle(records)
+        return records
+
+    noweight = 0.01  # give records with 0 weight a small chance
+    result = []
+    records = set(records)
+    while len(records) > 1:
+        # Compute the sum of the weights of those RRs. Then choose a
+        # uniform random number between 0 and the sum computed (inclusive).
+        urn = random.uniform(0, sum(rr.weight or noweight for rr in records))
+        # Select the RR whose running sum value is the first in the selected
+        # order which is greater than or equal to the random number selected.
+        acc = 0.
+        for rr in records.copy():
+            acc += rr.weight or noweight
+            if acc >= urn:
+                records.remove(rr)
+                result.append(rr)
+    if records:
+        result.append(records.pop())
+    return result
+
+
+def sort_prio_weight(records):
+    """RFC 2782 sorting algorithm for SRV and URI records
+
+    RFC 2782 defines a sorting algorithms for SRV records, that is also used
+    for sorting URI records. Records are sorted by priority and than randomly
+    shuffled according to weight.
+
+    This implementation also removes duplicate entries.
+    """
+    # order records by priority
+    records = sorted(records, key=operator.attrgetter("priority"))
+
+    # remove duplicate entries
+    uniquerecords = []
+    seen = set()
+    for rr in records:
+        # A SRV record has target and port, URI just has target.
+        target = (rr.target, getattr(rr, "port", None))
+        if target not in seen:
+            uniquerecords.append(rr)
+            seen.add(target)
+
+    # weighted randomization of entries with same priority
+    result = []
+    sameprio = []
+    for rr in uniquerecords:
+        # add all items with same priority in a bucket
+        if not sameprio or sameprio[0].priority == rr.priority:
+            sameprio.append(rr)
+        else:
+            # got different priority, shuffle bucket
+            result.extend(_mix_weight(sameprio))
+            # start a new priority list
+            sameprio = [rr]
+    # add last batch of records with same priority
+    if sameprio:
+        result.extend(_mix_weight(sameprio))
+    return result
+
+
+def query_srv(qname, resolver=None, **kwargs):
+    """Query SRV records and sort reply according to RFC 2782
+
+    :param qname: query name, _service._proto.domain.
+    :return: list of dns.rdtypes.IN.SRV.SRV instances
+    """
+    if resolver is None:
+        resolver = dns.resolver
+    answer = resolver.query(qname, rdtype=dns.rdatatype.SRV, **kwargs)
+    return sort_prio_weight(answer)
diff --git a/ipaserver/dcerpc.py b/ipaserver/dcerpc.py
index ac1b2a34784df491a3851aa21bbadbec2297241c..4e957b19292f51a7f6e3540dc38590737c7ae5e4 100644
--- a/ipaserver/dcerpc.py
+++ b/ipaserver/dcerpc.py
@@ -30,6 +30,7 @@ from ipalib import errors
 from ipapython import ipautil
 from ipapython.ipa_log_manager import root_logger
 from ipapython.dn import DN
+from ipapython.dnsutil import query_srv
 from ipaserver.install import installutils
 from ipaserver.dcerpc_common import (TRUST_BIDIRECTIONAL,
                                      TRUST_JOIN_EXTERNAL,
@@ -55,7 +56,6 @@ import samba
 import ldap as _ldap
 from ipapython import ipaldap
 from ipapython.dnsutil import DNSName
-from dns import resolver, rdatatype
 from dns.exception import DNSException
 import pysss_nss_idmap
 import pysss
@@ -795,7 +795,7 @@ class DomainValidator(object):
             gc_name = '_gc._tcp.%s.' % info['dns_domain']
 
             try:
-                answers = resolver.query(gc_name, rdatatype.SRV)
+                answers = query_srv(gc_name)
             except DNSException as e:
                 answers = []
 
diff --git a/ipatests/test_ipapython/test_dnsutil.py b/ipatests/test_ipapython/test_dnsutil.py
new file mode 100644
index 0000000000000000000000000000000000000000..36adb077cf38f6d036aa1048b201dee7d08eb310
--- /dev/null
+++ b/ipatests/test_ipapython/test_dnsutil.py
@@ -0,0 +1,106 @@
+#
+# Copyright (C) 2018  FreeIPA Contributors.  See COPYING for license
+#
+import dns.name
+import dns.rdataclass
+import dns.rdatatype
+from dns.rdtypes.IN.SRV import SRV
+from dns.rdtypes.ANY.URI import URI
+
+from ipapython import dnsutil
+
+import pytest
+
+
+def mksrv(priority, weight, port, target):
+    return SRV(
+        rdclass=dns.rdataclass.IN,
+        rdtype=dns.rdatatype.SRV,
+        priority=priority,
+        weight=weight,
+        port=port,
+        target=dns.name.from_text(target)
+    )
+
+
+def mkuri(priority, weight, target):
+    return URI(
+        rdclass=dns.rdataclass.IN,
+        rdtype=dns.rdatatype.URI,
+        priority=priority,
+        weight=weight,
+        target=target
+    )
+
+
+class TestSortSRV(object):
+    def test_empty(self):
+        assert dnsutil.sort_prio_weight([]) == []
+
+    def test_one(self):
+        h1 = mksrv(1, 0, 443, u"host1")
+        assert dnsutil.sort_prio_weight([h1]) == [h1]
+
+        h2 = mksrv(10, 5, 443, u"host2")
+        assert dnsutil.sort_prio_weight([h2]) == [h2]
+
+    def test_prio(self):
+        h1 = mksrv(1, 0, 443, u"host1")
+        h2 = mksrv(2, 0, 443, u"host2")
+        h3 = mksrv(3, 0, 443, u"host3")
+        assert dnsutil.sort_prio_weight([h3, h2, h1]) == [h1, h2, h3]
+        assert dnsutil.sort_prio_weight([h3, h3, h3]) == [h3]
+        assert dnsutil.sort_prio_weight([h2, h2, h1, h1]) == [h1, h2]
+
+        h380 = mksrv(4, 0, 80, u"host3")
+        assert dnsutil.sort_prio_weight([h1, h3, h380]) == [h1, h3, h380]
+
+        hs = mksrv(-1, 0, 443, u"special")
+        assert dnsutil.sort_prio_weight([h1, h2, hs]) == [hs, h1, h2]
+
+    def assert_permutations(self, answers, permutations):
+        seen = set()
+        for _unused in range(1000):
+            result = tuple(dnsutil.sort_prio_weight(answers))
+            assert result in permutations
+            seen.add(result)
+            if seen == permutations:
+                break
+        else:
+            pytest.fail("sorting didn't exhaust all permutations.")
+
+    def test_sameprio(self):
+        h1 = mksrv(1, 0, 443, u"host1")
+        h2 = mksrv(1, 0, 443, u"host2")
+        permutations = {
+            (h1, h2),
+            (h2, h1),
+        }
+        self.assert_permutations([h1, h2], permutations)
+
+    def test_weight(self):
+        h1 = mksrv(1, 0, 443, u"host1")
+        h2_w15 = mksrv(2, 15, 443, u"host2")
+        h3_w10 = mksrv(2, 10, 443, u"host3")
+
+        permutations = {
+            (h1, h2_w15, h3_w10),
+            (h1, h3_w10, h2_w15),
+        }
+        self.assert_permutations([h1, h2_w15, h3_w10], permutations)
+
+    def test_large(self):
+        records = tuple(
+            mksrv(1, i, 443, "host{}".format(i)) for i in range(1000)
+        )
+        assert len(dnsutil.sort_prio_weight(records)) == len(records)
+
+
+class TestSortURI(object):
+    def test_prio(self):
+        h1 = mkuri(1, 0, u"https://host1/api")
+        h2 = mkuri(2, 0, u"https://host2/api")
+        h3 = mkuri(3, 0, u"https://host3/api")
+        assert dnsutil.sort_prio_weight([h3, h2, h1]) == [h1, h2, h3]
+        assert dnsutil.sort_prio_weight([h3, h3, h3]) == [h3]
+        assert dnsutil.sort_prio_weight([h2, h2, h1, h1]) == [h1, h2]
-- 
2.17.1