95ea96
From a56b9ef37dde90a593da6adbae2525048c5c8627 Mon Sep 17 00:00:00 2001
2737e7
From: Christian Heimes <cheimes@redhat.com>
2737e7
Date: Fri, 15 Jun 2018 17:03:29 +0200
2737e7
Subject: [PATCH] Sort and shuffle SRV record by priority and weight
2737e7
2737e7
On multiple occasions, SRV query answers were not properly sorted by
2737e7
priority. Records with same priority weren't randomized and shuffled.
2737e7
This caused FreeIPA to contact the same remote peer instead of
2737e7
distributing the load across all available servers.
2737e7
2737e7
Two new helper functions now take care of SRV queries. sort_prio_weight()
2737e7
sorts SRV and URI records. query_srv() combines SRV lookup with
2737e7
sort_prio_weight().
2737e7
2737e7
Fixes: https://pagure.io/freeipa/issue/7475
2737e7
Signed-off-by: Christian Heimes <cheimes@redhat.com>
2737e7
Reviewed-By: Rob Crittenden <rcritten@redhat.com>
2737e7
---
95ea96
 ipaclient/install/ipadiscovery.py       |   5 +-
95ea96
 ipalib/rpc.py                           |  21 +++----
95ea96
 ipalib/util.py                          |  11 ++--
95ea96
 ipapython/config.py                     |   8 +--
95ea96
 ipapython/dnsutil.py                    |  92 ++++++++++++++++++++++++++-
2737e7
 ipaserver/dcerpc.py                     |   4 +-
95ea96
 ipatests/test_ipapython/test_dnsutil.py | 106 ++++++++++++++++++++++++++++++++
95ea96
 7 files changed, 217 insertions(+), 30 deletions(-)
2737e7
 create mode 100644 ipatests/test_ipapython/test_dnsutil.py
2737e7
2737e7
diff --git a/ipaclient/install/ipadiscovery.py b/ipaclient/install/ipadiscovery.py
95ea96
index 363970c86ba65b7ec49e403adc745fe61b2242aa..a8283d8e6ad070a8b6626a187ff55485bb74eba4 100644
2737e7
--- a/ipaclient/install/ipadiscovery.py
2737e7
+++ b/ipaclient/install/ipadiscovery.py
95ea96
@@ -20,7 +20,6 @@
95ea96
 from __future__ import absolute_import
95ea96
 
95ea96
 import logging
95ea96
-import operator
95ea96
 import socket
95ea96
 
95ea96
 import six
95ea96
@@ -28,6 +27,7 @@ import six
2737e7
 from dns import resolver, rdatatype
2737e7
 from dns.exception import DNSException
2737e7
 from ipalib import errors
2737e7
+from ipapython.dnsutil import query_srv
2737e7
 from ipapython import ipaldap
2737e7
 from ipaplatform.paths import paths
2737e7
 from ipapython.ipautil import valid_ip, realm_to_suffix
95ea96
@@ -498,8 +498,7 @@ class IPADiscovery(object):
95ea96
         logger.debug("Search DNS for SRV record of %s", qname)
2737e7
 
2737e7
         try:
2737e7
-            answers = resolver.query(qname, rdatatype.SRV)
95ea96
-            answers = sorted(answers, key=operator.attrgetter('priority'))
2737e7
+            answers = query_srv(qname)
2737e7
         except DNSException as e:
95ea96
             logger.debug("DNS record not found: %s", e.__class__.__name__)
2737e7
             answers = []
2737e7
diff --git a/ipalib/rpc.py b/ipalib/rpc.py
95ea96
index c6a8989f5dc157f4c4e59637e6e7d114c5fa952c..17368f160148d6b821ec4934d31b4b4034a7e67b 100644
2737e7
--- a/ipalib/rpc.py
2737e7
+++ b/ipalib/rpc.py
95ea96
@@ -45,7 +45,6 @@ import gzip
95ea96
 from cryptography import x509 as crypto_x509
2737e7
 
2737e7
 import gssapi
2737e7
-from dns import resolver, rdatatype
2737e7
 from dns.exception import DNSException
2737e7
 from ssl import SSLError
2737e7
 import six
95ea96
@@ -61,7 +60,7 @@ from ipalib.x509 import Encoding as x509_Encoding
2737e7
 from ipapython import ipautil
2737e7
 from ipapython import session_storage
2737e7
 from ipapython.cookie import Cookie
2737e7
-from ipapython.dnsutil import DNSName
2737e7
+from ipapython.dnsutil import DNSName, query_srv
2737e7
 from ipalib.text import _
2737e7
 from ipalib.util import create_https_connection
2737e7
 from ipalib.krb_utils import KRB5KDC_ERR_S_PRINCIPAL_UNKNOWN, KRB5KRB_AP_ERR_TKT_EXPIRED, \
95ea96
@@ -878,7 +877,7 @@ class RPCClient(Connectible):
2737e7
         name = '_ldap._tcp.%s.' % self.env.domain
2737e7
 
2737e7
         try:
2737e7
-            answers = resolver.query(name, rdatatype.SRV)
2737e7
+            answers = query_srv(name)
2737e7
         except DNSException:
2737e7
             answers = []
2737e7
 
95ea96
@@ -886,17 +885,11 @@ class RPCClient(Connectible):
2737e7
             server = str(answer.target).rstrip(".")
2737e7
             servers.append('https://%s%s' % (ipautil.format_netloc(server), path))
2737e7
 
2737e7
-        servers = list(set(servers))
2737e7
-        # the list/set conversion won't preserve order so stick in the
2737e7
-        # local config file version here.
2737e7
-        cfg_server = rpc_uri
2737e7
-        if cfg_server in servers:
2737e7
-            # make sure the configured master server is there just once and
2737e7
-            # it is the first one
2737e7
-            servers.remove(cfg_server)
2737e7
-            servers.insert(0, cfg_server)
2737e7
-        else:
2737e7
-            servers.insert(0, cfg_server)
2737e7
+        # make sure the configured master server is there just once and
2737e7
+        # it is the first one.
2737e7
+        if rpc_uri in servers:
2737e7
+            servers.remove(rpc_uri)
2737e7
+        servers.insert(0, rpc_uri)
2737e7
 
2737e7
         return servers
2737e7
 
2737e7
diff --git a/ipalib/util.py b/ipalib/util.py
95ea96
index ebf6eb3faf91cefc02514afa84ad9b63d0f82b0b..592821f9ff4f9c16fb2589d697e3e202443d0eba 100644
2737e7
--- a/ipalib/util.py
2737e7
+++ b/ipalib/util.py
95ea96
@@ -973,14 +973,13 @@ def detect_dns_zone_realm_type(api, domain):
2737e7
 
2737e7
     try:
2737e7
         # The presence of this record is enough, return foreign in such case
2737e7
-        result = resolver.query(ad_specific_record_name, rdatatype.SRV)
2737e7
-        return 'foreign'
2737e7
-
2737e7
+        resolver.query(ad_specific_record_name, rdatatype.SRV)
2737e7
     except DNSException:
2737e7
-        pass
2737e7
+        # If we could not detect type with certainty, return unknown
2737e7
+        return 'unknown'
2737e7
+    else:
2737e7
+        return 'foreign'
2737e7
 
2737e7
-    # If we could not detect type with certainity, return unknown
2737e7
-    return 'unknown'
2737e7
 
2737e7
 def has_managed_topology(api):
2737e7
     domainlevel = api.Command['domainlevel_get']().get('result', DOMAIN_LEVEL_0)
2737e7
diff --git a/ipapython/config.py b/ipapython/config.py
95ea96
index c3360779f1320e7d41f4d98b317d930a1d0b8242..f701122bd4828693a41848668e49052912042837 100644
2737e7
--- a/ipapython/config.py
2737e7
+++ b/ipapython/config.py
95ea96
@@ -26,7 +26,6 @@ from copy import copy
2737e7
 import socket
95ea96
 import functools
2737e7
 
2737e7
-from dns import resolver, rdatatype
2737e7
 from dns.exception import DNSException
2737e7
 import dns.name
2737e7
 # pylint: disable=import-error
95ea96
@@ -36,6 +35,7 @@ from six.moves.urllib.parse import urlsplit
2737e7
 
95ea96
 from ipaplatform.paths import paths
2737e7
 from ipapython.dn import DN
2737e7
+from ipapython.dnsutil import query_srv
95ea96
 from ipapython.ipautil import CheckedIPAddress, CheckedIPAddressLoopback
95ea96
 
2737e7
 
95ea96
@@ -210,7 +210,7 @@ def __discover_config(discover_server = True):
2737e7
             name = "_ldap._tcp." + domain
2737e7
 
2737e7
             try:
2737e7
-                servers = resolver.query(name, rdatatype.SRV)
2737e7
+                servers = query_srv(name)
2737e7
             except DNSException:
2737e7
                 # try cycling on domain components of FQDN
2737e7
                 try:
95ea96
@@ -225,7 +225,7 @@ def __discover_config(discover_server = True):
2737e7
                         return False
2737e7
                     name = "_ldap._tcp.%s" % domain
2737e7
                     try:
2737e7
-                        servers = resolver.query(name, rdatatype.SRV)
2737e7
+                        servers = query_srv(name)
2737e7
                         break
2737e7
                     except DNSException:
2737e7
                         pass
95ea96
@@ -236,7 +236,7 @@ def __discover_config(discover_server = True):
2737e7
             if not servers:
2737e7
                 name = "_ldap._tcp.%s." % config.default_domain
2737e7
                 try:
2737e7
-                    servers = resolver.query(name, rdatatype.SRV)
2737e7
+                    servers = query_srv(name)
2737e7
                 except DNSException:
2737e7
                     pass
2737e7
 
2737e7
diff --git a/ipapython/dnsutil.py b/ipapython/dnsutil.py
95ea96
index b40302d0efbb32108626d2cc1216ad4858343407..6157183a0fa5802a5fb772f078ee4c1688857fc8 100644
2737e7
--- a/ipapython/dnsutil.py
2737e7
+++ b/ipapython/dnsutil.py
95ea96
@@ -17,12 +17,17 @@
2737e7
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
2737e7
 #
2737e7
 
2737e7
+import copy
95ea96
 import logging
2737e7
+import operator
2737e7
+import random
95ea96
 
2737e7
 import dns.name
2737e7
 import dns.exception
2737e7
 import dns.resolver
2737e7
-import copy
2737e7
+import dns.rdataclass
2737e7
+import dns.rdatatype
95ea96
+
2737e7
 
2737e7
 import six
2737e7
 
95ea96
@@ -373,3 +378,88 @@ def check_zone_overlap(zone, raise_on_error=True):
2737e7
         if ns:
2737e7
             msg += u" and is handled by server(s): {0}".format(', '.join(ns))
2737e7
         raise ValueError(msg)
2737e7
+
2737e7
+
2737e7
+def _mix_weight(records):
2737e7
+    """Order same-priority records by RFC 2782 weighted random selection."""
2737e7
+    # Nothing to order: hand back the caller's list unchanged.
2737e7
+    if len(records) <= 1:
2737e7
+        return records
2737e7
+
2737e7
+    # Optimization for common case: If all weights are the same (e.g. 0),
2737e7
+    # just shuffle the records in place, which is about four times faster.
2737e7
+    if all(rr.weight == records[0].weight for rr in records):
2737e7
+        random.shuffle(records)
2737e7
+        return records
2737e7
+
2737e7
+    noweight = 0.01  # give records with 0 weight a small chance
2737e7
+    result = []
2737e7
+    records = set(records)
2737e7
+    while len(records) > 1:
2737e7
+        # Compute the sum of the weights of those RRs. Then choose a
2737e7
+        # uniform random number between 0 and the sum computed (inclusive).
2737e7
+        urn = random.uniform(0, sum(rr.weight or noweight for rr in records))
2737e7
+        # Select the RR whose running sum value is the first in the selected
2737e7
+        # order which is greater than or equal to the random number selected.
2737e7
+        acc = 0.
2737e7
+        for rr in records.copy():
2737e7
+            acc += rr.weight or noweight
2737e7
+            if acc >= urn:
2737e7
+                records.remove(rr)
2737e7
+                result.append(rr)
2737e7
+                break  # one RR per draw; redraw among the remaining records
2737e7
+    if records:
2737e7
+        result.append(records.pop())
2737e7
+    return result
2737e7
+
2737e7
+
2737e7
+def sort_prio_weight(records):
2737e7
+    """RFC 2782 sorting algorithm for SRV and URI records
2737e7
+
2737e7
+    RFC 2782 defines a sorting algorithm for SRV records that is also used
2737e7
+    for sorting URI records. Records are sorted by priority and then randomly
2737e7
+    shuffled according to weight.
2737e7
+
2737e7
+    Duplicates (same target and port) are removed; a new list is returned.
2737e7
+    """
2737e7
+    # order records by ascending priority value (sorted() is stable)
2737e7
+    records = sorted(records, key=operator.attrgetter("priority"))
2737e7
+
2737e7
+    # remove duplicate entries, keeping the first (lowest priority) one
2737e7
+    uniquerecords = []
2737e7
+    seen = set()
2737e7
+    for rr in records:
2737e7
+        # A SRV record has target and port, URI just has target.
2737e7
+        target = (rr.target, getattr(rr, "port", None))
2737e7
+        if target not in seen:
2737e7
+            uniquerecords.append(rr)
2737e7
+            seen.add(target)
2737e7
+
2737e7
+    # weighted randomization of entries with same priority
2737e7
+    result = []
2737e7
+    sameprio = []
2737e7
+    for rr in uniquerecords:
2737e7
+        # collect consecutive items with the same priority in a bucket
2737e7
+        if not sameprio or sameprio[0].priority == rr.priority:
2737e7
+            sameprio.append(rr)
2737e7
+        else:
2737e7
+            # got different priority, shuffle the finished bucket by weight
2737e7
+            result.extend(_mix_weight(sameprio))
2737e7
+            # start a new priority list
2737e7
+            sameprio = [rr]
2737e7
+    # add last batch of records with same priority
2737e7
+    if sameprio:
2737e7
+        result.extend(_mix_weight(sameprio))
2737e7
+    return result
2737e7
+
2737e7
+
2737e7
+def query_srv(qname, resolver=None, **kwargs):
2737e7
+    """Query SRV records and sort reply according to RFC 2782
2737e7
+
2737e7
+    :param qname: query name, _service._proto.domain.
2737e7
+    :return: sorted list of dns.rdtypes.IN.SRV.SRV instances
2737e7
+    """
2737e7
+    if resolver is None:  # no resolver object supplied by the caller
2737e7
+        resolver = dns.resolver  # use the dns.resolver module's query()
2737e7
+    answer = resolver.query(qname, rdtype=dns.rdatatype.SRV, **kwargs)
2737e7
+    return sort_prio_weight(answer)
2737e7
diff --git a/ipaserver/dcerpc.py b/ipaserver/dcerpc.py
95ea96
index e3aa9f6a6c239f8606f09bde06f21cb9cf64eb14..1e2d6fbbb5372bea7c6d68b73acededfe714a56a 100644
2737e7
--- a/ipaserver/dcerpc.py
2737e7
+++ b/ipaserver/dcerpc.py
95ea96
@@ -32,6 +32,7 @@ from ipalib import api, _
95ea96
 from ipalib import errors
2737e7
 from ipapython import ipautil
2737e7
 from ipapython.dn import DN
2737e7
+from ipapython.dnsutil import query_srv
95ea96
 from ipapython.ipaldap import ldap_initialize
2737e7
 from ipaserver.install import installutils
2737e7
 from ipaserver.dcerpc_common import (TRUST_BIDIRECTIONAL,
2737e7
@@ -55,7 +56,6 @@ import samba
2737e7
 import ldap as _ldap
2737e7
 from ipapython import ipaldap
2737e7
 from ipapython.dnsutil import DNSName
2737e7
-from dns import resolver, rdatatype
2737e7
 from dns.exception import DNSException
2737e7
 import pysss_nss_idmap
2737e7
 import pysss
95ea96
@@ -802,7 +802,7 @@ class DomainValidator(object):
2737e7
             gc_name = '_gc._tcp.%s.' % info['dns_domain']
2737e7
 
2737e7
             try:
2737e7
-                answers = resolver.query(gc_name, rdatatype.SRV)
2737e7
+                answers = query_srv(gc_name)
2737e7
             except DNSException as e:
2737e7
                 answers = []
2737e7
 
2737e7
diff --git a/ipatests/test_ipapython/test_dnsutil.py b/ipatests/test_ipapython/test_dnsutil.py
2737e7
new file mode 100644
2737e7
index 0000000000000000000000000000000000000000..36adb077cf38f6d036aa1048b201dee7d08eb310
2737e7
--- /dev/null
2737e7
+++ b/ipatests/test_ipapython/test_dnsutil.py
2737e7
@@ -0,0 +1,106 @@
2737e7
+#
2737e7
+# Copyright (C) 2018  FreeIPA Contributors.  See COPYING for license
2737e7
+#
2737e7
+import dns.name
2737e7
+import dns.rdataclass
2737e7
+import dns.rdatatype
2737e7
+from dns.rdtypes.IN.SRV import SRV
2737e7
+from dns.rdtypes.ANY.URI import URI
2737e7
+
2737e7
+from ipapython import dnsutil
2737e7
+
2737e7
+import pytest
2737e7
+
2737e7
+
2737e7
+def mksrv(priority, weight, port, target):
2737e7
+    return SRV(
2737e7
+        rdclass=dns.rdataclass.IN,
2737e7
+        rdtype=dns.rdatatype.SRV,
2737e7
+        priority=priority,
2737e7
+        weight=weight,
2737e7
+        port=port,
2737e7
+        target=dns.name.from_text(target)
2737e7
+    )
2737e7
+
2737e7
+
2737e7
+def mkuri(priority, weight, target):
2737e7
+    return URI(
2737e7
+        rdclass=dns.rdataclass.IN,
2737e7
+        rdtype=dns.rdatatype.URI,
2737e7
+        priority=priority,
2737e7
+        weight=weight,
2737e7
+        target=target
2737e7
+    )
2737e7
+
2737e7
+
2737e7
+class TestSortSRV(object):
2737e7
+    def test_empty(self):
2737e7
+        assert dnsutil.sort_prio_weight([]) == []
2737e7
+
2737e7
+    def test_one(self):
2737e7
+        h1 = mksrv(1, 0, 443, u"host1")
2737e7
+        assert dnsutil.sort_prio_weight([h1]) == [h1]
2737e7
+
2737e7
+        h2 = mksrv(10, 5, 443, u"host2")
2737e7
+        assert dnsutil.sort_prio_weight([h2]) == [h2]
2737e7
+
2737e7
+    def test_prio(self):
2737e7
+        h1 = mksrv(1, 0, 443, u"host1")
2737e7
+        h2 = mksrv(2, 0, 443, u"host2")
2737e7
+        h3 = mksrv(3, 0, 443, u"host3")
2737e7
+        assert dnsutil.sort_prio_weight([h3, h2, h1]) == [h1, h2, h3]
2737e7
+        assert dnsutil.sort_prio_weight([h3, h3, h3]) == [h3]
2737e7
+        assert dnsutil.sort_prio_weight([h2, h2, h1, h1]) == [h1, h2]
2737e7
+
2737e7
+        h380 = mksrv(4, 0, 80, u"host3")
2737e7
+        assert dnsutil.sort_prio_weight([h1, h3, h380]) == [h1, h3, h380]
2737e7
+
2737e7
+        hs = mksrv(-1, 0, 443, u"special")
2737e7
+        assert dnsutil.sort_prio_weight([h1, h2, hs]) == [hs, h1, h2]
2737e7
+
2737e7
+    def assert_permutations(self, answers, permutations):
2737e7
+        seen = set()
2737e7
+        for _unused in range(1000):
2737e7
+            result = tuple(dnsutil.sort_prio_weight(answers))
2737e7
+            assert result in permutations
2737e7
+            seen.add(result)
2737e7
+            if seen == permutations:
2737e7
+                break
2737e7
+        else:
2737e7
+            pytest.fail("sorting didn't exhaust all permutations.")
2737e7
+
2737e7
+    def test_sameprio(self):
2737e7
+        h1 = mksrv(1, 0, 443, u"host1")
2737e7
+        h2 = mksrv(1, 0, 443, u"host2")
2737e7
+        permutations = {
2737e7
+            (h1, h2),
2737e7
+            (h2, h1),
2737e7
+        }
2737e7
+        self.assert_permutations([h1, h2], permutations)
2737e7
+
2737e7
+    def test_weight(self):
2737e7
+        h1 = mksrv(1, 0, 443, u"host1")
2737e7
+        h2_w15 = mksrv(2, 15, 443, u"host2")
2737e7
+        h3_w10 = mksrv(2, 10, 443, u"host3")
2737e7
+
2737e7
+        permutations = {
2737e7
+            (h1, h2_w15, h3_w10),
2737e7
+            (h1, h3_w10, h2_w15),
2737e7
+        }
2737e7
+        self.assert_permutations([h1, h2_w15, h3_w10], permutations)
2737e7
+
2737e7
+    def test_large(self):
2737e7
+        records = tuple(
2737e7
+            mksrv(1, i, 443, "host{}".format(i)) for i in range(1000)
2737e7
+        )
2737e7
+        assert len(dnsutil.sort_prio_weight(records)) == len(records)
2737e7
+
2737e7
+
2737e7
+class TestSortURI(object):
2737e7
+    def test_prio(self):
2737e7
+        h1 = mkuri(1, 0, u"https://host1/api")
2737e7
+        h2 = mkuri(2, 0, u"https://host2/api")
2737e7
+        h3 = mkuri(3, 0, u"https://host3/api")
2737e7
+        assert dnsutil.sort_prio_weight([h3, h2, h1]) == [h1, h2, h3]
2737e7
+        assert dnsutil.sort_prio_weight([h3, h3, h3]) == [h3]
2737e7
+        assert dnsutil.sort_prio_weight([h2, h2, h1, h1]) == [h1, h2]
2737e7
-- 
95ea96
2.14.4
2737e7