From a56b9ef37dde90a593da6adbae2525048c5c8627 Mon Sep 17 00:00:00 2001 From: Christian Heimes Date: Fri, 15 Jun 2018 17:03:29 +0200 Subject: [PATCH] Sort and shuffle SRV record by priority and weight On multiple occasions, SRV query answers were not properly sorted by priority. Records with same priority weren't randomized and shuffled. This caused FreeIPA to contact the same remote peer instead of distributing the load across all available servers. Two new helper functions now take care of SRV queries. sort_prio_weight() sorts SRV and URI records. query_srv() combines SRV lookup with sort_prio_weight(). Fixes: https://pagure.io/freeipa/issue/7475 Signed-off-by: Christian Heimes Reviewed-By: Rob Crittenden --- ipaclient/install/ipadiscovery.py | 5 +- ipalib/rpc.py | 21 +++---- ipalib/util.py | 11 ++-- ipapython/config.py | 8 +-- ipapython/dnsutil.py | 92 ++++++++++++++++++++++++++- ipaserver/dcerpc.py | 4 +- ipatests/test_ipapython/test_dnsutil.py | 106 ++++++++++++++++++++++++++++++++ 7 files changed, 217 insertions(+), 30 deletions(-) create mode 100644 ipatests/test_ipapython/test_dnsutil.py diff --git a/ipaclient/install/ipadiscovery.py b/ipaclient/install/ipadiscovery.py index 363970c86ba65b7ec49e403adc745fe61b2242aa..a8283d8e6ad070a8b6626a187ff55485bb74eba4 100644 --- a/ipaclient/install/ipadiscovery.py +++ b/ipaclient/install/ipadiscovery.py @@ -20,7 +20,6 @@ from __future__ import absolute_import import logging -import operator import socket import six @@ -28,6 +27,7 @@ import six from dns import resolver, rdatatype from dns.exception import DNSException from ipalib import errors +from ipapython.dnsutil import query_srv from ipapython import ipaldap from ipaplatform.paths import paths from ipapython.ipautil import valid_ip, realm_to_suffix @@ -498,8 +498,7 @@ class IPADiscovery(object): logger.debug("Search DNS for SRV record of %s", qname) try: - answers = resolver.query(qname, rdatatype.SRV) - answers = sorted(answers, key=operator.attrgetter('priority')) + answers = query_srv(qname) except DNSException as e: logger.debug("DNS record not found: %s", e.__class__.__name__) answers = [] diff --git a/ipalib/rpc.py b/ipalib/rpc.py index c6a8989f5dc157f4c4e59637e6e7d114c5fa952c..17368f160148d6b821ec4934d31b4b4034a7e67b 100644 --- a/ipalib/rpc.py +++ b/ipalib/rpc.py @@ -45,7 +45,6 @@ import gzip from cryptography import x509 as crypto_x509 import gssapi -from dns import resolver, rdatatype from dns.exception import DNSException from ssl import SSLError import six @@ -61,7 +60,7 @@ from ipalib.x509 import Encoding as x509_Encoding from ipapython import ipautil from ipapython import session_storage from ipapython.cookie import Cookie -from ipapython.dnsutil import DNSName +from ipapython.dnsutil import DNSName, query_srv from ipalib.text import _ from ipalib.util import create_https_connection from ipalib.krb_utils import KRB5KDC_ERR_S_PRINCIPAL_UNKNOWN, KRB5KRB_AP_ERR_TKT_EXPIRED, \ @@ -878,7 +877,7 @@ class RPCClient(Connectible): name = '_ldap._tcp.%s.' % self.env.domain try: - answers = resolver.query(name, rdatatype.SRV) + answers = query_srv(name) except DNSException: answers = [] @@ -886,17 +885,11 @@ class RPCClient(Connectible): server = str(answer.target).rstrip(".") servers.append('https://%s%s' % (ipautil.format_netloc(server), path)) - servers = list(set(servers)) - # the list/set conversion won't preserve order so stick in the - # local config file version here. - cfg_server = rpc_uri - if cfg_server in servers: - # make sure the configured master server is there just once and - # it is the first one - servers.remove(cfg_server) - servers.insert(0, cfg_server) - else: - servers.insert(0, cfg_server) + # make sure the configured master server is there just once and + # it is the first one. + if rpc_uri in servers: + servers.remove(rpc_uri) + servers.insert(0, rpc_uri) return servers diff --git a/ipalib/util.py b/ipalib/util.py index ebf6eb3faf91cefc02514afa84ad9b63d0f82b0b..592821f9ff4f9c16fb2589d697e3e202443d0eba 100644 --- a/ipalib/util.py +++ b/ipalib/util.py @@ -973,14 +973,13 @@ def detect_dns_zone_realm_type(api, domain): try: # The presence of this record is enough, return foreign in such case - result = resolver.query(ad_specific_record_name, rdatatype.SRV) - return 'foreign' - + resolver.query(ad_specific_record_name, rdatatype.SRV) except DNSException: - pass + # If we could not detect type with certainty, return unknown + return 'unknown' + else: + return 'foreign' - # If we could not detect type with certainity, return unknown - return 'unknown' def has_managed_topology(api): domainlevel = api.Command['domainlevel_get']().get('result', DOMAIN_LEVEL_0) diff --git a/ipapython/config.py b/ipapython/config.py index c3360779f1320e7d41f4d98b317d930a1d0b8242..f701122bd4828693a41848668e49052912042837 100644 --- a/ipapython/config.py +++ b/ipapython/config.py @@ -26,7 +26,6 @@ from copy import copy import socket import functools -from dns import resolver, rdatatype from dns.exception import DNSException import dns.name # pylint: disable=import-error @@ -36,6 +35,7 @@ from six.moves.urllib.parse import urlsplit from ipaplatform.paths import paths from ipapython.dn import DN +from ipapython.dnsutil import query_srv from ipapython.ipautil import CheckedIPAddress, CheckedIPAddressLoopback @@ -210,7 +210,7 @@ def __discover_config(discover_server = True): name = "_ldap._tcp." + domain try: - servers = resolver.query(name, rdatatype.SRV) + servers = query_srv(name) except DNSException: # try cycling on domain components of FQDN try: @@ -225,7 +225,7 @@ def __discover_config(discover_server = True): return False name = "_ldap._tcp.%s" % domain try: - servers = resolver.query(name, rdatatype.SRV) + servers = query_srv(name) break except DNSException: pass @@ -236,7 +236,7 @@ def __discover_config(discover_server = True): if not servers: name = "_ldap._tcp.%s." % config.default_domain try: - servers = resolver.query(name, rdatatype.SRV) + servers = query_srv(name) except DNSException: pass diff --git a/ipapython/dnsutil.py b/ipapython/dnsutil.py index b40302d0efbb32108626d2cc1216ad4858343407..6157183a0fa5802a5fb772f078ee4c1688857fc8 100644 --- a/ipapython/dnsutil.py +++ b/ipapython/dnsutil.py @@ -17,12 +17,17 @@ # along with this program. If not, see . # +import copy import logging +import operator +import random import dns.name import dns.exception import dns.resolver -import copy +import dns.rdataclass +import dns.rdatatype + import six @@ -373,3 +378,88 @@ def check_zone_overlap(zone, raise_on_error=True): if ns: msg += u" and is handled by server(s): {0}".format(', '.join(ns)) raise ValueError(msg) + + +def _mix_weight(records): + """Weighted population sorting for records with same priority + """ + # trivial case + if len(records) <= 1: + return records + + # Optimization for common case: If all weights are the same (e.g. 0), + # just shuffle the records, which is about four times faster. + if all(rr.weight == records[0].weight for rr in records): + random.shuffle(records) + return records + + noweight = 0.01 # give records with 0 weight a small chance + result = [] + records = set(records) + while len(records) > 1: + # Compute the sum of the weights of those RRs. Then choose a + # uniform random number between 0 and the sum computed (inclusive). + urn = random.uniform(0, sum(rr.weight or noweight for rr in records)) + # Select the RR whose running sum value is the first in the selected + # order which is greater than or equal to the random number selected. + acc = 0. + for rr in records.copy(): + acc += rr.weight or noweight + if acc >= urn: + records.remove(rr) + result.append(rr) + if records: + result.append(records.pop()) + return result + + +def sort_prio_weight(records): + """RFC 2782 sorting algorithm for SRV and URI records + + RFC 2782 defines a sorting algorithms for SRV records, that is also used + for sorting URI records. Records are sorted by priority and than randomly + shuffled according to weight. + + This implementation also removes duplicate entries. + """ + # order records by priority + records = sorted(records, key=operator.attrgetter("priority")) + + # remove duplicate entries + uniquerecords = [] + seen = set() + for rr in records: + # A SRV record has target and port, URI just has target. + target = (rr.target, getattr(rr, "port", None)) + if target not in seen: + uniquerecords.append(rr) + seen.add(target) + + # weighted randomization of entries with same priority + result = [] + sameprio = [] + for rr in uniquerecords: + # add all items with same priority in a bucket + if not sameprio or sameprio[0].priority == rr.priority: + sameprio.append(rr) + else: + # got different priority, shuffle bucket + result.extend(_mix_weight(sameprio)) + # start a new priority list + sameprio = [rr] + # add last batch of records with same priority + if sameprio: + result.extend(_mix_weight(sameprio)) + return result + + +def query_srv(qname, resolver=None, **kwargs): + """Query SRV records and sort reply according to RFC 2782 + + :param qname: query name, _service._proto.domain. + :return: list of dns.rdtypes.IN.SRV.SRV instances + """ + if resolver is None: + resolver = dns.resolver + answer = resolver.query(qname, rdtype=dns.rdatatype.SRV, **kwargs) + return sort_prio_weight(answer) diff --git a/ipaserver/dcerpc.py b/ipaserver/dcerpc.py index e3aa9f6a6c239f8606f09bde06f21cb9cf64eb14..1e2d6fbbb5372bea7c6d68b73acededfe714a56a 100644 --- a/ipaserver/dcerpc.py +++ b/ipaserver/dcerpc.py @@ -32,6 +32,7 @@ from ipalib import api, _ from ipalib import errors from ipapython import ipautil from ipapython.dn import DN +from ipapython.dnsutil import query_srv from ipapython.ipaldap import ldap_initialize from ipaserver.install import installutils from ipaserver.dcerpc_common import (TRUST_BIDIRECTIONAL, @@ -55,7 +56,6 @@ import samba import ldap as _ldap from ipapython import ipaldap from ipapython.dnsutil import DNSName -from dns import resolver, rdatatype from dns.exception import DNSException import pysss_nss_idmap import pysss @@ -802,7 +802,7 @@ class DomainValidator(object): gc_name = '_gc._tcp.%s.' % info['dns_domain'] try: - answers = resolver.query(gc_name, rdatatype.SRV) + answers = query_srv(gc_name) except DNSException as e: answers = [] diff --git a/ipatests/test_ipapython/test_dnsutil.py b/ipatests/test_ipapython/test_dnsutil.py new file mode 100644 index 0000000000000000000000000000000000000000..36adb077cf38f6d036aa1048b201dee7d08eb310 --- /dev/null +++ b/ipatests/test_ipapython/test_dnsutil.py @@ -0,0 +1,106 @@ +# +# Copyright (C) 2018 FreeIPA Contributors. See COPYING for license +# +import dns.name +import dns.rdataclass +import dns.rdatatype +from dns.rdtypes.IN.SRV import SRV +from dns.rdtypes.ANY.URI import URI + +from ipapython import dnsutil + +import pytest + + +def mksrv(priority, weight, port, target): + return SRV( + rdclass=dns.rdataclass.IN, + rdtype=dns.rdatatype.SRV, + priority=priority, + weight=weight, + port=port, + target=dns.name.from_text(target) + ) + + +def mkuri(priority, weight, target): + return URI( + rdclass=dns.rdataclass.IN, + rdtype=dns.rdatatype.URI, + priority=priority, + weight=weight, + target=target + ) + + +class TestSortSRV(object): + def test_empty(self): + assert dnsutil.sort_prio_weight([]) == [] + + def test_one(self): + h1 = mksrv(1, 0, 443, u"host1") + assert dnsutil.sort_prio_weight([h1]) == [h1] + + h2 = mksrv(10, 5, 443, u"host2") + assert dnsutil.sort_prio_weight([h2]) == [h2] + + def test_prio(self): + h1 = mksrv(1, 0, 443, u"host1") + h2 = mksrv(2, 0, 443, u"host2") + h3 = mksrv(3, 0, 443, u"host3") + assert dnsutil.sort_prio_weight([h3, h2, h1]) == [h1, h2, h3] + assert dnsutil.sort_prio_weight([h3, h3, h3]) == [h3] + assert dnsutil.sort_prio_weight([h2, h2, h1, h1]) == [h1, h2] + + h380 = mksrv(4, 0, 80, u"host3") + assert dnsutil.sort_prio_weight([h1, h3, h380]) == [h1, h3, h380] + + hs = mksrv(-1, 0, 443, u"special") + assert dnsutil.sort_prio_weight([h1, h2, hs]) == [hs, h1, h2] + + def assert_permutations(self, answers, permutations): + seen = set() + for _unused in range(1000): + result = tuple(dnsutil.sort_prio_weight(answers)) + assert result in permutations + seen.add(result) + if seen == permutations: + break + else: + pytest.fail("sorting didn't exhaust all permutations.") + + def test_sameprio(self): + h1 = mksrv(1, 0, 443, u"host1") + h2 = mksrv(1, 0, 443, u"host2") + permutations = { + (h1, h2), + (h2, h1), + } + self.assert_permutations([h1, h2], permutations) + + def test_weight(self): + h1 = mksrv(1, 0, 443, u"host1") + h2_w15 = mksrv(2, 15, 443, u"host2") + h3_w10 = mksrv(2, 10, 443, u"host3") + + permutations = { + (h1, h2_w15, h3_w10), + (h1, h3_w10, h2_w15), + } + self.assert_permutations([h1, h2_w15, h3_w10], permutations) + + def test_large(self): + records = tuple( + mksrv(1, i, 443, "host{}".format(i)) for i in range(1000) + ) + assert len(dnsutil.sort_prio_weight(records)) == len(records) + + +class TestSortURI(object): + def test_prio(self): + h1 = mkuri(1, 0, u"https://host1/api") + h2 = mkuri(2, 0, u"https://host2/api") + h3 = mkuri(3, 0, u"https://host3/api") + assert dnsutil.sort_prio_weight([h3, h2, h1]) == [h1, h2, h3] + assert dnsutil.sort_prio_weight([h3, h3, h3]) == [h3] + assert dnsutil.sort_prio_weight([h2, h2, h1, h1]) == [h1, h2] -- 2.14.4