Blob Blame History Raw
From 9f998164b87cd00dfb7a5750b898a2e17359c31c Mon Sep 17 00:00:00 2001
From: Eric Garver <e@erig.me>
Date: Fri, 19 Oct 2018 09:22:04 -0400
Subject: [PATCH 24/34] nftables: support rich rule priorities

Fixes: #149
Fixes: #224
(cherry picked from commit 25e9a62532d6395ff13665db02aabaa010d6fca5)
---
 src/firewall/core/nftables.py | 214 +++++++++++++++++++++++++++-------
 1 file changed, 174 insertions(+), 40 deletions(-)

diff --git a/src/firewall/core/nftables.py b/src/firewall/core/nftables.py
index a763ed3ec103..d59bc55bf1a5 100644
--- a/src/firewall/core/nftables.py
+++ b/src/firewall/core/nftables.py
@@ -20,6 +20,7 @@
 #
 
 import os.path
+import copy
 
 from firewall.core.base import SHORTCUTS, DEFAULT_ZONE_TARGET
 from firewall.core.prog import runProg
@@ -29,7 +30,8 @@ from firewall.functions import splitArgs, check_mac, portStr, \
 from firewall import config
 from firewall.errors import FirewallError, UNKNOWN_ERROR, INVALID_RULE, \
                             INVALID_ICMPTYPE, INVALID_TYPE, INVALID_ENTRY
-from firewall.core.rich import Rich_Accept, Rich_Reject, Rich_Drop, Rich_Mark
+from firewall.core.rich import Rich_Accept, Rich_Reject, Rich_Drop, Rich_Mark, \
+                               Rich_Masquerade, Rich_ForwardPort, Rich_IcmpBlock
 
 TABLE_NAME = "firewalld"
 
@@ -160,6 +162,7 @@ class nftables(object):
         self.available_tables = []
         self.rule_to_handle = {}
         self.rule_ref_count = {}
+        self.rich_rule_priority_counts = {}
 
     def fill_exists(self):
         self.command_exists = os.path.exists(self._command)
@@ -171,18 +174,11 @@ class nftables(object):
 
         def rule_key_from_rule(rule):
             rule_key = rule[2:]
-            if rule_key[3] in ["position", "handle", "index"]:
-                # strip "position #"
-                # "insert rule family table chain position <num>"
-                #              ^^ rule_key starts here
-                try:
-                    int(rule_key[4])
-                except Exception:
-                    raise FirewallError(INVALID_RULE, "position without a number")
-                else:
-                    rule_key.pop(3)
-                    rule_key.pop(3)
-            return " ".join(rule_key)
+            # "insert rule family table chain index <num>"
+            #              ^^ rule_key starts here
+            if rule_key[3] in ["position", "handle"]:
+                raise FirewallError(INVALID_RULE, "position/handle not allowed in rule")
+            return " ".join([str(x) for x in rule_key])
 
         # If we're deleting a table (i.e. build_flush_rules())
         # then check if its exist first to avoid nft throwing an error
@@ -200,11 +196,6 @@ class nftables(object):
         elif _args[0] in ["delete"] and _args[1] == "rule":
             rule_add = False
             rule_key = rule_key_from_rule(_args)
-            # delete using rule handle
-            _args = ["delete", "rule"] + _args[2:5] + \
-                    ["handle", self.rule_to_handle[rule_key]]
-
-        _args_str = " ".join(_args)
 
         # rule deduplication
         if rule_key in self.rule_ref_count:
@@ -220,22 +211,82 @@ class nftables(object):
                 raise FirewallError(UNKNOWN_ERROR, "rule ref count bug: rule_key '%s', cnt %d"
                                                    % (rule_key, self.rule_ref_count[rule_key]))
             log.debug2("%s: rule ref cnt %d, %s %s", self.__class__,
-                       self.rule_ref_count[rule_key], self._command, _args_str)
+                       self.rule_ref_count[rule_key], self._command,
+                       " ".join([str(x) for x in _args]))
+
+        # replace %%RICH_RULE_PRIORITY%%
+        if rule_key:
+            rich_rule_priority_counts = self.rich_rule_priority_counts
+            try:
+                i = _args.index("%%RICH_RULE_PRIORITY%%")
+            except ValueError:
+                pass
+            else:
+                rich_rule_priority_counts = copy.deepcopy(self.rich_rule_priority_counts)
+                _args.pop(i)
+                priority = _args.pop(i)
+                if type(priority) != int:
+                    raise FirewallError(INVALID_RULE, "rich rule priority must be followed by a number")
+                chain = (_args[2], _args[4]) # family, chain
+                # Add the rule to the priority counts. We don't need to store the
+                # rule, just bump the ref count for the priority value.
+                if not rule_add:
+                    if chain not in rich_rule_priority_counts or \
+                       priority not in rich_rule_priority_counts[chain] or \
+                       rich_rule_priority_counts[chain][priority] <= 0:
+                        raise FirewallError(UNKNOWN_ERROR, "nonexistent or underflow of rich rule priority count")
+
+                    rich_rule_priority_counts[chain][priority] -= 1
+                else:
+                    if chain not in rich_rule_priority_counts:
+                        rich_rule_priority_counts[chain] = {}
+                    if priority not in rich_rule_priority_counts[chain]:
+                        rich_rule_priority_counts[chain][priority] = 0
+
+                    # calculate index of new rule
+                    index = 0
+                    for p in sorted(rich_rule_priority_counts[chain].keys()):
+                        if p == priority and _args[0] == "insert":
+                            break
+                        index += rich_rule_priority_counts[chain][p]
+                        if p == priority and _args[0] == "add":
+                            break
+
+                    rich_rule_priority_counts[chain][priority] += 1
+
+                    if index == 0:
+                        _args[0] = "insert"
+                    else:
+                        index -= 1 # point to the rule before insertion point
+                        _args[0] = "add"
+                        _args.insert(i, "index")
+                        _args.insert(i+1, "%d" % index)
 
         if not rule_key or (not rule_add and self.rule_ref_count[rule_key] == 0) \
                         or (    rule_add and rule_key not in self.rule_ref_count):
+
+            # delete using rule handle
+            if rule_key and not rule_add:
+                _args = ["delete", "rule"] + _args[2:5] + \
+                    ["handle", self.rule_to_handle[rule_key]]
+
+            _args_str = " ".join(_args)
             log.debug2("%s: %s %s", self.__class__, self._command, _args_str)
             (status, output) = runProg(self._command, nft_opts + _args)
             if status != 0:
                 raise ValueError("'%s %s' failed: %s" % (self._command,
                                                          _args_str, output))
+
+            if rule_key:
+                self.rich_rule_priority_counts = rich_rule_priority_counts
+
             # nft requires deleting rules by handle. So we must cache the rule
             # handle when adding/inserting rules.
             #
             if rule_key:
                 if rule_add:
-                    str = "# handle "
-                    offset = output.index(str) + len(str)
+                    handle_str = "# handle "
+                    offset = output.index(handle_str) + len(handle_str)
                     self.rule_to_handle[rule_key] = output[offset:].strip()
                     self.rule_ref_count[rule_key] = 1
                 else:
@@ -305,6 +356,7 @@ class nftables(object):
     def build_flush_rules(self):
         self.rule_to_handle = {}
         self.rule_ref_count = {}
+        self.rich_rule_priority_counts = {}
 
         rules = []
         for family in OUR_CHAINS.keys():
@@ -557,18 +609,27 @@ class nftables(object):
         OUR_CHAINS[family][table].update(set([_zone,
                                          "%s_log" % _zone,
                                          "%s_deny" % _zone,
+                                         "%s_rich_rule_pre" % _zone,
+                                         "%s_rich_rule_post" % _zone,
                                          "%s_allow" % _zone]))
 
         rules = []
         rules.append(["add", "chain", family, "%s" % TABLE_NAME,
                       "%s_%s" % (table, _zone)])
+        rules.append(["add", "chain", family, "%s" % TABLE_NAME,
+                      "%s_%s_rich_rule_pre" % (table, _zone)])
         rules.append(["add", "chain", family, "%s" % TABLE_NAME,
                       "%s_%s_log" % (table, _zone)])
         rules.append(["add", "chain", family, "%s" % TABLE_NAME,
                       "%s_%s_deny" % (table, _zone)])
         rules.append(["add", "chain", family, "%s" % TABLE_NAME,
                       "%s_%s_allow" % (table, _zone)])
+        rules.append(["add", "chain", family, "%s" % TABLE_NAME,
+                      "%s_%s_rich_rule_post" % (table, _zone)])
 
+        rules.append(["add", "rule", family, "%s" % TABLE_NAME,
+                      "%s_%s" % (table, _zone),
+                      "jump", "%s_%s_rich_rule_pre" % (table, _zone)])
         rules.append(["add", "rule", family, "%s" % TABLE_NAME,
                       "%s_%s" % (table, _zone),
                       "jump", "%s_%s_log" % (table, _zone)])
@@ -578,6 +639,9 @@ class nftables(object):
         rules.append(["add", "rule", family, "%s" % TABLE_NAME,
                       "%s_%s" % (table, _zone),
                       "jump", "%s_%s_allow" % (table, _zone)])
+        rules.append(["add", "rule", family, "%s" % TABLE_NAME,
+                      "%s_%s" % (table, _zone),
+                      "jump", "%s_%s_rich_rule_post" % (table, _zone)])
 
         target = self._fw.zone._zones[zone].target
 
@@ -659,14 +723,54 @@ class nftables(object):
         return ["limit", "rate", limit.value[0:i], "/",
                 rich_to_nft[limit.value[i+1]]]
 
+    def _rich_rule_chain_suffix(self, rich_rule):
+        if type(rich_rule.element) in [Rich_Masquerade, Rich_ForwardPort, Rich_IcmpBlock]:
+            # These are special and don't have an explicit action
+            pass
+        elif rich_rule.action:
+            if type(rich_rule.action) not in [Rich_Accept, Rich_Reject, Rich_Drop, Rich_Mark]:
+                raise FirewallError(INVALID_RULE, "Unknown action %s" % type(rich_rule.action))
+        else:
+            raise FirewallError(INVALID_RULE, "No rule action specified.")
+
+        if rich_rule.priority == 0:
+            if type(rich_rule.element) in [Rich_Masquerade, Rich_ForwardPort] or \
+               type(rich_rule.action) in [Rich_Accept, Rich_Mark]:
+                return "allow"
+            elif type(rich_rule.element) in [Rich_IcmpBlock] or \
+                 type(rich_rule.action) in [Rich_Reject, Rich_Drop]:
+                return "deny"
+        elif rich_rule.priority < 0:
+            return "rich_rule_pre"
+        else:
+            return "rich_rule_post"
+
+    def _rich_rule_chain_suffix_from_log(self, rich_rule):
+        if not rich_rule.log and not rich_rule.audit:
+            raise FirewallError(INVALID_RULE, "Not log or audit")
+
+        if rich_rule.priority == 0:
+            return "log"
+        elif rich_rule.priority < 0:
+            return "rich_rule_pre"
+        else:
+            return "rich_rule_post"
+
+    def _rich_rule_priority_fragment(self, rich_rule):
+        if rich_rule.priority == 0:
+            return []
+        return ["%%RICH_RULE_PRIORITY%%", rich_rule.priority]
+
     def _rich_rule_log(self, rich_rule, enable, table, target, rule_fragment):
         if not rich_rule.log:
             return []
 
         add_del = { True: "add", False: "delete" }[enable]
 
+        chain_suffix = self._rich_rule_chain_suffix_from_log(rich_rule)
         rule = [add_del, "rule", "inet", "%s" % TABLE_NAME,
-                "%s_%s_log" % (table, target)]
+                "%s_%s_%s" % (table, target, chain_suffix)]
+        rule += self._rich_rule_priority_fragment(rich_rule)
         rule += rule_fragment + ["log"]
         if rich_rule.log.prefix:
             rule += ["prefix", "\"%s\"" % rich_rule.log.prefix]
@@ -682,8 +786,10 @@ class nftables(object):
 
         add_del = { True: "add", False: "delete" }[enable]
 
+        chain_suffix = self._rich_rule_chain_suffix_from_log(rich_rule)
         rule = [add_del, "rule", "inet", "%s" % TABLE_NAME,
-                "%s_%s_log" % (table, target)]
+                "%s_%s_%s" % (table, target, chain_suffix)]
+        rule += self._rich_rule_priority_fragment(rich_rule)
         rule += rule_fragment + ["log", "level", "audit"]
         rule += self._rich_rule_limit_fragment(rich_rule.audit.limit)
 
@@ -695,28 +801,28 @@ class nftables(object):
 
         add_del = { True: "add", False: "delete" }[enable]
 
+        chain_suffix = self._rich_rule_chain_suffix(rich_rule)
+        chain = "%s_%s_%s" % (table, target, chain_suffix)
         if type(rich_rule.action) == Rich_Accept:
-            chain = "%s_%s_allow" % (table, target)
             rule_action = ["accept"]
         elif type(rich_rule.action) == Rich_Reject:
-            chain = "%s_%s_deny" % (table, target)
             rule_action = ["reject"]
             if rich_rule.action.type:
                 rule_action += self._reject_types_fragment(rich_rule.action.type)
         elif type(rich_rule.action) ==  Rich_Drop:
-            chain = "%s_%s_deny" % (table, target)
             rule_action = ["drop"]
         elif type(rich_rule.action) == Rich_Mark:
             target = DEFAULT_ZONE_TARGET.format(chain=SHORTCUTS["PREROUTING"],
                                                 zone=zone)
             table = "mangle"
-            chain = "%s_%s_allow" % (table, target)
+            chain = "%s_%s_%s" % (table, target, chain_suffix)
             rule_action = ["meta", "mark", "set", rich_rule.action.set]
         else:
             raise FirewallError(INVALID_RULE,
                                 "Unknown action %s" % type(rich_rule.action))
 
         rule = [add_del, "rule", "inet", "%s" % TABLE_NAME, chain]
+        rule += self._rich_rule_priority_fragment(rich_rule)
         rule += rule_fragment
         rule += self._rich_rule_limit_fragment(rich_rule.action.limit)
         rule += rule_action
@@ -902,11 +1008,15 @@ class nftables(object):
 
         rule_fragment = []
         if rich_rule:
+            rule_fragment += self._rich_rule_priority_fragment(rich_rule)
             rule_fragment += self._rich_rule_destination_fragment(rich_rule.destination)
             rule_fragment += self._rich_rule_source_fragment(rich_rule.source)
+            chain_suffix = self._rich_rule_chain_suffix(rich_rule)
+        else:
+            chain_suffix = "allow"
 
         return [[add_del, "rule", family, "%s" % TABLE_NAME,
-                "nat_%s_allow" % (target)]
+                "nat_%s_%s" % (target, chain_suffix)]
                 + rule_fragment + ["oifname", "!=", "lo", "masquerade"]]
 
     def build_zone_masquerade_rules(self, enable, zone, rich_rule=None):
@@ -928,18 +1038,22 @@ class nftables(object):
 
         rule_fragment = []
         if rich_rule:
+            rule_fragment += self._rich_rule_priority_fragment(rich_rule)
             rule_fragment += self._rich_rule_destination_fragment(rich_rule.destination)
             rule_fragment += self._rich_rule_source_fragment(rich_rule.source)
+            chain_suffix = self._rich_rule_chain_suffix(rich_rule)
+        else:
+            chain_suffix = "allow"
 
         rules.append([add_del, "rule", "inet", "%s" % TABLE_NAME,
-                      "filter_%s_allow" % (target)]
+                      "filter_%s_%s" % (target, chain_suffix)]
                       + rule_fragment + ["ct", "state", "new,untracked", "accept"])
 
         return rules
 
     def _build_zone_forward_port_nat_rules(self, enable, zone, protocol,
                                            mark_fragment, toaddr, toport,
-                                           family):
+                                           family, rich_rule=None):
         add_del = { True: "add", False: "delete" }[enable]
         target = DEFAULT_ZONE_TARGET.format(chain=SHORTCUTS["PREROUTING"],
                                             zone=zone)
@@ -953,8 +1067,17 @@ class nftables(object):
         if toport and toport != "":
             dnat_fragment += [":%s" % portStr(toport, "-")]
 
+        rich_rule_priority_fragment = []
+        if rich_rule:
+            rich_rule_priority_fragment += self._rich_rule_priority_fragment(rich_rule)
+            chain_suffix = self._rich_rule_chain_suffix(rich_rule)
+        else:
+            chain_suffix = "allow"
+
         return [[add_del, "rule", family, "%s" % TABLE_NAME,
-                "nat_%s_allow" % (target), "meta", "l4proto", protocol]
+                "nat_%s_%s" % (target, chain_suffix)]
+                + rich_rule_priority_fragment +
+                ["meta", "l4proto", protocol]
                 + mark_fragment + dnat_fragment]
 
     def build_zone_forward_port_rules(self, enable, zone, filter_chain, port,
@@ -968,36 +1091,45 @@ class nftables(object):
                                             zone=zone)
         rule_fragment = []
         if rich_rule:
+            rule_fragment += self._rich_rule_priority_fragment(rich_rule)
             rule_fragment += self._rich_rule_family_fragment(rich_rule.family)
             rule_fragment += self._rich_rule_destination_fragment(rich_rule.destination)
             rule_fragment += self._rich_rule_source_fragment(rich_rule.source)
+            chain_suffix = self._rich_rule_chain_suffix(rich_rule)
+        else:
+            chain_suffix = "allow"
 
         rules = []
         rules.append([add_del, "rule", "inet", "%s" % TABLE_NAME,
-                      "mangle_%s_allow" % (target)]
+                      "mangle_%s_%s" % (target, chain_suffix)]
                       + rule_fragment +
                       [protocol, "dport", port, "meta", "mark", "set", mark_str])
 
         if rich_rule and (rich_rule.family and rich_rule.family == "ipv6"
            or toaddr and check_single_address("ipv6", toaddr)):
             rules.extend(self._build_zone_forward_port_nat_rules(enable, zone,
-                                protocol, mark_fragment, toaddr, toport, "ip6"))
+                                protocol, mark_fragment, toaddr, toport, "ip6", rich_rule))
         elif rich_rule and (rich_rule.family and rich_rule.family == "ipv4"
            or toaddr and check_single_address("ipv4", toaddr)):
             rules.extend(self._build_zone_forward_port_nat_rules(enable, zone,
-                                protocol, mark_fragment, toaddr, toport, "ip"))
+                                protocol, mark_fragment, toaddr, toport, "ip", rich_rule))
         else:
             if not toaddr or check_single_address("ipv6", toaddr):
                 rules.extend(self._build_zone_forward_port_nat_rules(enable, zone,
-                                    protocol, mark_fragment, toaddr, toport, "ip6"))
+                                    protocol, mark_fragment, toaddr, toport, "ip6", rich_rule))
             if not toaddr or check_single_address("ipv4", toaddr):
                 rules.extend(self._build_zone_forward_port_nat_rules(enable, zone,
-                                    protocol, mark_fragment, toaddr, toport, "ip"))
+                                    protocol, mark_fragment, toaddr, toport, "ip", rich_rule))
 
         target = DEFAULT_ZONE_TARGET.format(chain=SHORTCUTS[filter_chain],
                                             zone=zone)
+        rule_fragment = []
+        if rich_rule:
+            rule_fragment += self._rich_rule_priority_fragment(rich_rule)
         rules.append([add_del, "rule", "inet", "%s" % TABLE_NAME,
-                      "filter_%s_allow" % (target), "ct", "state", "new,untracked"]
+                      "filter_%s_%s" % (target, chain_suffix)]
+                      + rule_fragment +
+                      ["ct", "state", "new,untracked"]
                       + mark_fragment + ["accept"])
 
         return rules
@@ -1049,8 +1181,10 @@ class nftables(object):
                     if rich_rule.action:
                         rules.append(self._rich_rule_action(zone, rich_rule, enable, table, target, rule_fragment))
                     else:
+                        chain_suffix = self._rich_rule_chain_suffix(rich_rule)
                         rules.append([add_del, "rule", "inet", "%s" % TABLE_NAME,
-                                      "%s_%s_deny" % (table, target)]
+                                      "%s_%s_%s" % (table, target, chain_suffix)]
+                                      + self._rich_rule_priority_fragment(rich_rule)
                                       + rule_fragment + ["%%REJECT%%"])
                 else:
                     if self._fw.get_log_denied() != "off" and final_target != "accept":
@@ -1079,14 +1213,14 @@ class nftables(object):
 
             # WARN: index must be kept in sync with build_zone_chain_rules()
             rules.append([add_del, "rule", "inet", "%s" % TABLE_NAME,
-                          "%s_%s" % (table, _zone), "index", "2",
+                          "%s_%s" % (table, _zone), "index", "4",
                           "%%ICMP%%", ibi_target])
 
             if self._fw.zone.query_icmp_block_inversion(zone):
                 if self._fw.get_log_denied() != "off":
                     # WARN: index must be kept in sync with build_zone_chain_rules()
                     rules.append([add_del, "rule", "inet", "%s" % TABLE_NAME,
-                                  "%s_%s" % (table, _zone), "index", "2",
+                                  "%s_%s" % (table, _zone), "index", "4",
                                   "%%ICMP%%", "%%LOGTYPE%%", "log", "prefix",
                                   "\"%s_%s_ICMP_BLOCK: \"" % (table, _zone)])
 
-- 
2.18.0