345f2a
commit b003a0da7865caa25b5d1e70c79329b32409b02a (HEAD -> refs/heads/v4, refs/remotes/origin/v4)
345f2a
Author: Amos Jeffries <yadij@users.noreply.github.com>
345f2a
Date:   2021-09-24 21:53:11 +0000
345f2a
345f2a
    WCCP: Validate packets better (#899)
345f2a
    
345f2a
    Update WCCP to support exception based error handling for
345f2a
    parsing and processing we are moving Squid to for protocol
345f2a
    handling.
345f2a
    
345f2a
    Update the main WCCPv2 parsing checks to throw meaningful
345f2a
    exceptions when detected.
345f2a
345f2a
diff --git a/src/wccp2.cc b/src/wccp2.cc
345f2a
index ee592449c..6ef469e91 100644
345f2a
--- a/src/wccp2.cc
345f2a
+++ b/src/wccp2.cc
345f2a
@@ -1108,6 +1108,59 @@ wccp2ConnectionClose(void)
345f2a
  * Functions for handling the requests.
345f2a
  */
345f2a
 
345f2a
+/// Checks that the given area section ends inside the given (whole) area.
345f2a
+/// \param error the message to throw when the section does not fit
345f2a
+static void
345f2a
+CheckSectionLength(const void *sectionStart, const size_t sectionLength, const void *wholeStart, const size_t wholeSize, const char *error)
345f2a
+{
345f2a
+    assert(sectionStart);
345f2a
+    assert(wholeStart);
345f2a
+
345f2a
+    const auto wholeEnd = static_cast<const char*>(wholeStart) + wholeSize;
345f2a
+    assert(sectionStart >= wholeStart && "we never go backwards");
345f2a
+    assert(sectionStart <= wholeEnd && "we never go beyond our whole (but zero-sized fields are OK)");
345f2a
+    static_assert(sizeof(wccp2_i_see_you_t) <= PTRDIFF_MAX, "paranoid: no UB when subtracting in-whole pointers");
345f2a
+    // subtraction safe due to the three assertions above
345f2a
+    const auto remainderDiff = wholeEnd - static_cast<const char*>(sectionStart);
345f2a
+
345f2a
+    // casting safe due to the assertions above (and size_t definition)
345f2a
+    assert(remainderDiff >= 0);
345f2a
+    const auto remainderSize = static_cast<size_t>(remainderDiff);
345f2a
+
345f2a
+    if (sectionLength <= remainderSize)
345f2a
+        return;
345f2a
+
345f2a
+    throw TextException(error, Here());
345f2a
+}
345f2a
+
345f2a
+/// Checks that the area contains at least dataLength bytes after the header.
345f2a
+/// The size of the field header itself is not included in dataLength.
345f2a
+/// \returns the total field size -- the field header and field data combined
345f2a
+template<class FieldHeader>
345f2a
+static size_t
345f2a
+CheckFieldDataLength(const FieldHeader *header, const size_t dataLength, const void *areaStart, const size_t areaSize, const char *error)
345f2a
+{
345f2a
+    assert(header);
345f2a
+    const auto dataStart = reinterpret_cast<const char*>(header) + sizeof(header);
345f2a
+    CheckSectionLength(dataStart, dataLength, areaStart, areaSize, error);
345f2a
+    return sizeof(header) + dataLength; // no overflow after CheckSectionLength()
345f2a
+}
345f2a
+
345f2a
+/// Positions the given field at a given start within a given packet area.
345f2a
+/// The Field type determines the correct field size (used for bounds checking).
345f2a
+/// \param field the field pointer the function should set
345f2a
+/// \param areaStart the start of a packet (sub)structure containing the field
345f2a
+/// \param areaSize the size of the packet (sub)structure starting at areaStart
345f2a
+/// \param fieldStart the start of a field within the given area
345f2a
+/// \param error the message to throw when the field does not fit the area
345f2a
+template<class Field>
345f2a
+static void
345f2a
+SetField(Field *&field, const void *fieldStart, const void *areaStart, const size_t areaSize, const char *error)
345f2a
+{
345f2a
+    CheckSectionLength(fieldStart, sizeof(Field), areaStart, areaSize, error);
345f2a
+    field = static_cast<Field*>(const_cast<void*>(fieldStart));
345f2a
+}
345f2a
+
345f2a
 /*
345f2a
  * Accept the UDP packet
345f2a
  */
345f2a
@@ -1124,8 +1177,6 @@ wccp2HandleUdp(int sock, void *)
345f2a
 
345f2a
     /* These structs form the parts of the packet */
345f2a
 
345f2a
-    struct wccp2_item_header_t *header = NULL;
345f2a
-
345f2a
     struct wccp2_security_none_t *security_info = NULL;
345f2a
 
345f2a
     struct wccp2_service_info_t *service_info = NULL;
345f2a
@@ -1141,14 +1192,13 @@ wccp2HandleUdp(int sock, void *)
345f2a
     struct wccp2_cache_identity_info_t *cache_identity = NULL;
345f2a
 
345f2a
     struct wccp2_capability_info_header_t *router_capability_header = NULL;
345f2a
+    char *router_capability_data_start = nullptr;
345f2a
 
345f2a
     struct wccp2_capability_element_t *router_capability_element;
345f2a
 
345f2a
     struct sockaddr_in from;
345f2a
 
345f2a
     struct in_addr cache_address;
345f2a
-    int len, found;
345f2a
-    short int data_length, offset;
345f2a
     uint32_t tmp;
345f2a
     char *ptr;
345f2a
     int num_caches;
345f2a
@@ -1161,20 +1211,18 @@ wccp2HandleUdp(int sock, void *)
345f2a
     Ip::Address from_tmp;
345f2a
     from_tmp.setIPv4();
345f2a
 
345f2a
-    len = comm_udp_recvfrom(sock,
345f2a
-                            &wccp2_i_see_you,
345f2a
-                            WCCP_RESPONSE_SIZE,
345f2a
-                            0,
345f2a
-                            from_tmp);
345f2a
+    const auto lenOrError = comm_udp_recvfrom(sock, &wccp2_i_see_you, WCCP_RESPONSE_SIZE, 0, from_tmp);
345f2a
 
345f2a
-    if (len < 0)
345f2a
+    if (lenOrError < 0)
345f2a
         return;
345f2a
+    const auto len = static_cast<size_t>(lenOrError);
345f2a
 
345f2a
-    if (ntohs(wccp2_i_see_you.version) != WCCP2_VERSION)
345f2a
-        return;
345f2a
-
345f2a
-    if (ntohl(wccp2_i_see_you.type) != WCCP2_I_SEE_YOU)
345f2a
-        return;
345f2a
+    try {
345f2a
+        // TODO: Remove wccp2_i_see_you.data and use a buffer to read messages.
345f2a
+        const auto message_header_size = sizeof(wccp2_i_see_you) - sizeof(wccp2_i_see_you.data);
345f2a
+        Must2(len >= message_header_size, "incomplete WCCP message header");
345f2a
+        Must2(ntohs(wccp2_i_see_you.version) == WCCP2_VERSION, "WCCP version unsupported");
345f2a
+        Must2(ntohl(wccp2_i_see_you.type) == WCCP2_I_SEE_YOU, "WCCP packet type unsupported");
345f2a
 
345f2a
     /* FIXME INET6 : drop conversion boundary */
345f2a
     from_tmp.getSockAddr(from);
345f2a
@@ -1182,73 +1230,60 @@ wccp2HandleUdp(int sock, void *)
345f2a
     debugs(80, 3, "Incoming WCCPv2 I_SEE_YOU length " << ntohs(wccp2_i_see_you.length) << ".");
345f2a
 
345f2a
     /* Record the total data length */
345f2a
-    data_length = ntohs(wccp2_i_see_you.length);
345f2a
+    const auto data_length = ntohs(wccp2_i_see_you.length);
345f2a
+    Must2(data_length <= len - message_header_size,
345f2a
+          "malformed packet claiming it's bigger than received data");
345f2a
 
345f2a
-    offset = 0;
345f2a
-
345f2a
-    if (data_length > len) {
345f2a
-        debugs(80, DBG_IMPORTANT, "ERROR: Malformed WCCPv2 packet claiming it's bigger than received data");
345f2a
-        return;
345f2a
-    }
345f2a
+    size_t offset = 0;
345f2a
 
345f2a
     /* Go through the data structure */
345f2a
-    while (data_length > offset) {
345f2a
+    while (offset + sizeof(struct wccp2_item_header_t) <= data_length) {
345f2a
 
345f2a
         char *data = wccp2_i_see_you.data;
345f2a
 
345f2a
-        header = (struct wccp2_item_header_t *) &data[offset];
345f2a
+        const auto itemHeader = reinterpret_cast<const wccp2_item_header_t*>(&data[offset]);
345f2a
+        const auto itemSize = CheckFieldDataLength(itemHeader, ntohs(itemHeader->length),
345f2a
+                              data, data_length, "truncated record");
345f2a
+        // XXX: Check "The specified length must be a multiple of 4 octets"
345f2a
+        // requirement to avoid unaligned memory reads after the first item.
345f2a
 
345f2a
-        switch (ntohs(header->type)) {
345f2a
+        switch (ntohs(itemHeader->type)) {
345f2a
 
345f2a
         case WCCP2_SECURITY_INFO:
345f2a
-
345f2a
-            if (security_info != NULL) {
345f2a
-                debugs(80, DBG_IMPORTANT, "Duplicate security definition");
345f2a
-                return;
345f2a
-            }
345f2a
-
345f2a
-            security_info = (struct wccp2_security_none_t *) &wccp2_i_see_you.data[offset];
345f2a
+            Must2(!security_info, "duplicate security definition");
345f2a
+            SetField(security_info, itemHeader, itemHeader, itemSize,
345f2a
+                     "security definition truncated");
345f2a
             break;
345f2a
 
345f2a
         case WCCP2_SERVICE_INFO:
345f2a
-
345f2a
-            if (service_info != NULL) {
345f2a
-                debugs(80, DBG_IMPORTANT, "Duplicate service_info definition");
345f2a
-                return;
345f2a
-            }
345f2a
-
345f2a
-            service_info = (struct wccp2_service_info_t *) &wccp2_i_see_you.data[offset];
345f2a
+            Must2(!service_info, "duplicate service_info definition");
345f2a
+            SetField(service_info, itemHeader, itemHeader, itemSize,
345f2a
+                     "service_info definition truncated");
345f2a
             break;
345f2a
 
345f2a
         case WCCP2_ROUTER_ID_INFO:
345f2a
-
345f2a
-            if (router_identity_info != NULL) {
345f2a
-                debugs(80, DBG_IMPORTANT, "Duplicate router_identity_info definition");
345f2a
-                return;
345f2a
-            }
345f2a
-
345f2a
-            router_identity_info = (struct router_identity_info_t *) &wccp2_i_see_you.data[offset];
345f2a
+            Must2(!router_identity_info, "duplicate router_identity_info definition");
345f2a
+            SetField(router_identity_info, itemHeader, itemHeader, itemSize,
345f2a
+                     "router_identity_info definition truncated");
345f2a
             break;
345f2a
 
345f2a
         case WCCP2_RTR_VIEW_INFO:
345f2a
-
345f2a
-            if (router_view_header != NULL) {
345f2a
-                debugs(80, DBG_IMPORTANT, "Duplicate router_view definition");
345f2a
-                return;
345f2a
-            }
345f2a
-
345f2a
-            router_view_header = (struct router_view_t *) &wccp2_i_see_you.data[offset];
345f2a
+            Must2(!router_view_header, "duplicate router_view definition");
345f2a
+            SetField(router_view_header, itemHeader, itemHeader, itemSize,
345f2a
+                     "router_view definition truncated");
345f2a
             break;
345f2a
 
345f2a
-        case WCCP2_CAPABILITY_INFO:
345f2a
-
345f2a
-            if (router_capability_header != NULL) {
345f2a
-                debugs(80, DBG_IMPORTANT, "Duplicate router_capability definition");
345f2a
-                return;
345f2a
-            }
345f2a
+        case WCCP2_CAPABILITY_INFO: {
345f2a
+            Must2(!router_capability_header, "duplicate router_capability definition");
345f2a
+            SetField(router_capability_header, itemHeader, itemHeader, itemSize,
345f2a
+                     "router_capability definition truncated");
345f2a
 
345f2a
-            router_capability_header = (struct wccp2_capability_info_header_t *) &wccp2_i_see_you.data[offset];
345f2a
+            CheckFieldDataLength(router_capability_header, ntohs(router_capability_header->capability_info_length),
345f2a
+                                 itemHeader, itemSize, "capability info truncated");
345f2a
+            router_capability_data_start = reinterpret_cast<char*>(router_capability_header) +
345f2a
+                                           sizeof(*router_capability_header);
345f2a
             break;
345f2a
+        }
345f2a
 
345f2a
         /* Nothing to do for the types below */
345f2a
 
345f2a
@@ -1257,22 +1292,17 @@ wccp2HandleUdp(int sock, void *)
345f2a
             break;
345f2a
 
345f2a
         default:
345f2a
-            debugs(80, DBG_IMPORTANT, "Unknown record type in WCCPv2 Packet (" << ntohs(header->type) << ").");
345f2a
+            debugs(80, DBG_IMPORTANT, "Unknown record type in WCCPv2 Packet (" << ntohs(itemHeader->type) << ").");
345f2a
         }
345f2a
 
345f2a
-        offset += sizeof(struct wccp2_item_header_t);
345f2a
-        offset += ntohs(header->length);
345f2a
-
345f2a
-        if (offset > data_length) {
345f2a
-            debugs(80, DBG_IMPORTANT, "Error: WCCPv2 packet tried to tell us there is data beyond the end of the packet");
345f2a
-            return;
345f2a
-        }
345f2a
+        offset += itemSize;
345f2a
+        assert(offset <= data_length && "CheckFieldDataLength(itemHeader...) established that");
345f2a
     }
345f2a
 
345f2a
-    if ((security_info == NULL) || (service_info == NULL) || (router_identity_info == NULL) || (router_view_header == NULL)) {
345f2a
-        debugs(80, DBG_IMPORTANT, "Incomplete WCCPv2 Packet");
345f2a
-        return;
345f2a
-    }
345f2a
+    Must2(security_info, "packet missing security definition");
345f2a
+    Must2(service_info, "packet missing service_info definition");
345f2a
+    Must2(router_identity_info, "packet missing router_identity_info definition");
345f2a
+    Must2(router_view_header, "packet missing router_view definition");
345f2a
 
345f2a
     debugs(80, 5, "Complete packet received");
345f2a
 
345f2a
@@ -1308,10 +1338,7 @@ wccp2HandleUdp(int sock, void *)
345f2a
             break;
345f2a
     }
345f2a
 
345f2a
-    if (router_list_ptr->next == NULL) {
345f2a
-        debugs(80, DBG_IMPORTANT, "WCCPv2 Packet received from unknown router");
345f2a
-        return;
345f2a
-    }
345f2a
+    Must2(router_list_ptr->next, "packet received from unknown router");
345f2a
 
345f2a
     /* Set the router id */
345f2a
     router_list_ptr->info->router_address = router_identity_info->router_id_element.router_address;
345f2a
@@ -1331,11 +1358,20 @@ wccp2HandleUdp(int sock, void *)
345f2a
         }
345f2a
     } else {
345f2a
 
345f2a
-        char *end = ((char *) router_capability_header) + sizeof(*router_capability_header) + ntohs(router_capability_header->capability_info_length) - sizeof(struct wccp2_capability_info_header_t);
345f2a
-
345f2a
-        router_capability_element = (struct wccp2_capability_element_t *) (((char *) router_capability_header) + sizeof(*router_capability_header));
345f2a
-
345f2a
-        while ((char *) router_capability_element <= end) {
345f2a
+        const auto router_capability_data_length = ntohs(router_capability_header->capability_info_length);
345f2a
+        assert(router_capability_data_start);
345f2a
+        const auto router_capability_data_end = router_capability_data_start +
345f2a
+                                                router_capability_data_length;
345f2a
+        for (auto router_capability_data_current = router_capability_data_start;
345f2a
+                router_capability_data_current < router_capability_data_end;) {
345f2a
+
345f2a
+            SetField(router_capability_element, router_capability_data_current,
345f2a
+                     router_capability_data_start, router_capability_data_length,
345f2a
+                     "capability element header truncated");
345f2a
+            const auto elementSize = CheckFieldDataLength(
345f2a
+                                         router_capability_element, ntohs(router_capability_element->capability_length),
345f2a
+                                         router_capability_data_start, router_capability_data_length,
345f2a
+                                         "capability element truncated");
345f2a
 
345f2a
             switch (ntohs(router_capability_element->capability_type)) {
345f2a
 
345f2a
@@ -1377,7 +1413,7 @@ wccp2HandleUdp(int sock, void *)
345f2a
                 debugs(80, DBG_IMPORTANT, "Unknown capability type in WCCPv2 Packet (" << ntohs(router_capability_element->capability_type) << ").");
345f2a
             }
345f2a
 
345f2a
-            router_capability_element = (struct wccp2_capability_element_t *) (((char *) router_capability_element) + sizeof(struct wccp2_item_header_t) + ntohs(router_capability_element->capability_length));
345f2a
+            router_capability_data_current += elementSize;
345f2a
         }
345f2a
     }
345f2a
 
345f2a
@@ -1396,23 +1432,34 @@ wccp2HandleUdp(int sock, void *)
345f2a
     num_caches = 0;
345f2a
 
345f2a
     /* Check to see if we're the master cache and update the cache list */
345f2a
-    found = 0;
345f2a
+    bool found = false;
345f2a
     service_list_ptr->lowest_ip = 1;
345f2a
     cache_list_ptr = &router_list_ptr->cache_list_head;
345f2a
 
345f2a
     /* to find the list of caches, we start at the end of the router view header */
345f2a
 
345f2a
     ptr = (char *) (router_view_header) + sizeof(struct router_view_t);
345f2a
+    const auto router_view_size = sizeof(struct router_view_t) +
345f2a
+                                  ntohs(router_view_header->header.length);
345f2a
 
345f2a
     /* Then we read the number of routers */
345f2a
-    memcpy(&tmp, ptr, sizeof(tmp));
345f2a
+    const uint32_t *routerCountRaw = nullptr;
345f2a
+    SetField(routerCountRaw, ptr, router_view_header, router_view_size,
345f2a
+             "malformed packet (truncated router view info w/o number of routers)");
345f2a
 
345f2a
     /* skip the number plus all the ip's */
345f2a
-
345f2a
-    ptr += sizeof(tmp) + (ntohl(tmp) * sizeof(struct in_addr));
345f2a
+    ptr += sizeof(*routerCountRaw);
345f2a
+    const auto ipCount = ntohl(*routerCountRaw);
345f2a
+    const auto ipsSize = ipCount * sizeof(struct in_addr); // we check for unsigned overflow below
345f2a
+    Must2(ipsSize / sizeof(struct in_addr) != ipCount, "huge IP address count");
345f2a
+    CheckSectionLength(ptr, ipsSize, router_view_header, router_view_size, "invalid IP address count");
345f2a
+    ptr += ipsSize;
345f2a
 
345f2a
     /* Then read the number of caches */
345f2a
-    memcpy(&tmp, ptr, sizeof(tmp));
345f2a
+    const uint32_t *cacheCountRaw = nullptr;
345f2a
+    SetField(cacheCountRaw, ptr, router_view_header, router_view_size,
345f2a
+             "malformed packet (truncated router view info w/o cache count)");
345f2a
+    memcpy(&tmp, cacheCountRaw, sizeof(tmp)); // TODO: Replace tmp with cacheCount
345f2a
     ptr += sizeof(tmp);
345f2a
 
345f2a
     if (ntohl(tmp) != 0) {
345f2a
@@ -1426,7 +1473,8 @@ wccp2HandleUdp(int sock, void *)
345f2a
 
345f2a
             case WCCP2_ASSIGNMENT_METHOD_HASH:
345f2a
 
345f2a
-                cache_identity = (struct wccp2_cache_identity_info_t *) ptr;
345f2a
+                SetField(cache_identity, ptr, router_view_header, router_view_size,
345f2a
+                         "malformed packet (truncated router view info cache w/o assignment hash)");
345f2a
 
345f2a
                 ptr += sizeof(struct wccp2_cache_identity_info_t);
345f2a
 
345f2a
@@ -1437,13 +1485,15 @@ wccp2HandleUdp(int sock, void *)
345f2a
 
345f2a
             case WCCP2_ASSIGNMENT_METHOD_MASK:
345f2a
 
345f2a
-                cache_mask_info = (struct cache_mask_info_t *) ptr;
345f2a
+                SetField(cache_mask_info, ptr, router_view_header, router_view_size,
345f2a
+                         "malformed packet (truncated router view info cache w/o assignment mask)");
345f2a
 
345f2a
                 /* The mask assignment has an undocumented variable length entry here */
345f2a
 
345f2a
                 if (ntohl(cache_mask_info->num1) == 3) {
345f2a
 
345f2a
-                    cache_mask_identity = (struct wccp2_cache_mask_identity_info_t *) ptr;
345f2a
+                    SetField(cache_mask_identity, ptr, router_view_header, router_view_size,
345f2a
+                             "malformed packet (truncated router view info cache w/o assignment mask identity)");
345f2a
 
345f2a
                     ptr += sizeof(struct wccp2_cache_mask_identity_info_t);
345f2a
 
345f2a
@@ -1474,10 +1524,7 @@ wccp2HandleUdp(int sock, void *)
345f2a
             debugs (80, 5,  "checking cache list: (" << std::hex << cache_address.s_addr << ":" <<  router_list_ptr->local_ip.s_addr << ")");
345f2a
 
345f2a
             /* Check to see if it's the master, or us */
345f2a
-
345f2a
-            if (cache_address.s_addr == router_list_ptr->local_ip.s_addr) {
345f2a
-                found = 1;
345f2a
-            }
345f2a
+            found = found || (cache_address.s_addr == router_list_ptr->local_ip.s_addr);
345f2a
 
345f2a
             if (cache_address.s_addr < router_list_ptr->local_ip.s_addr) {
345f2a
                 service_list_ptr->lowest_ip = 0;
345f2a
@@ -1494,7 +1541,7 @@ wccp2HandleUdp(int sock, void *)
345f2a
         cache_list_ptr->next = NULL;
345f2a
 
345f2a
         service_list_ptr->lowest_ip = 1;
345f2a
-        found = 1;
345f2a
+        found = true;
345f2a
         num_caches = 1;
345f2a
     }
345f2a
 
345f2a
@@ -1502,7 +1549,7 @@ wccp2HandleUdp(int sock, void *)
345f2a
 
345f2a
     router_list_ptr->num_caches = htonl(num_caches);
345f2a
 
345f2a
-    if ((found == 1) && (service_list_ptr->lowest_ip == 1)) {
345f2a
+    if (found && (service_list_ptr->lowest_ip == 1)) {
345f2a
         if (ntohl(router_view_header->change_number) != router_list_ptr->member_change) {
345f2a
             debugs(80, 4, "Change detected - queueing up new assignment");
345f2a
             router_list_ptr->member_change = ntohl(router_view_header->change_number);
345f2a
@@ -1515,6 +1562,10 @@ wccp2HandleUdp(int sock, void *)
345f2a
         eventDelete(wccp2AssignBuckets, NULL);
345f2a
         debugs(80, 5, "I am not the lowest ip cache - not assigning buckets");
345f2a
     }
345f2a
+
345f2a
+    } catch (...) {
345f2a
+        debugs(80, DBG_IMPORTANT, "ERROR: Ignoring WCCPv2 message: " << CurrentException);
345f2a
+    }
345f2a
 }
345f2a
 
345f2a
 static void