Blob Blame History Raw
From 42b3c2ed11c1e62c1691f868a6796983f93c3beb Mon Sep 17 00:00:00 2001
From: Greg Hudson <ghudson@mit.edu>
Date: Wed, 9 Apr 2014 13:19:03 -0400
Subject: [PATCH 01/13] Simplify sendto_kdc.c

* Get rid of the "x" member of conn_state, which used to be a union
  but hasn't been since r14742.
* Define a structure type for the "out" member of conn_state.
* Rename incoming_krb5_message to incoming_message for brevity.
* Make the "pos" member of incoming_message an offset instead of a
  pointer, simplifying several present and future computations.
* Use "in" and "out" aliases to the conn_state in and out members
  where it improves brevity.
* Rename set_conn_state_msg_length to set_transport_message and give
  it a descriptive comment.
* Call set_transport_message from start_connection only, instead of
  once in add_connection and perhaps again in start_connection.  To
  make this possible, pass the original message argument to maybe_send
  and start_connection.
* Use make_data and empty_data helpers where appropriate.
---
 src/lib/krb5/os/sendto_kdc.c | 159 +++++++++++++++++++++----------------------
 1 file changed, 79 insertions(+), 80 deletions(-)

diff --git a/src/lib/krb5/os/sendto_kdc.c b/src/lib/krb5/os/sendto_kdc.c
index 67e2a60..5f781d3 100644
--- a/src/lib/krb5/os/sendto_kdc.c
+++ b/src/lib/krb5/os/sendto_kdc.c
@@ -76,30 +76,30 @@ static const char *const state_strings[] = {
 
 /* connection states */
 enum conn_states { INITIALIZING, CONNECTING, WRITING, READING, FAILED };
-struct incoming_krb5_message {
+struct incoming_message {
     size_t bufsizebytes_read;
     size_t bufsize;
+    size_t pos;
     char *buf;
-    char *pos;
     unsigned char bufsizebytes[4];
     size_t n_left;
 };
 
+struct outgoing_message {
+    sg_buf sgbuf[2];
+    sg_buf *sgp;
+    int sg_count;
+    unsigned char msg_len_buf[4];
+};
+
 struct conn_state {
     SOCKET fd;
     enum conn_states state;
     int (*service)(krb5_context context, struct conn_state *,
                    struct select_state *, int);
     struct remote_address addr;
-    struct {
-        struct {
-            sg_buf sgbuf[2];
-            sg_buf *sgp;
-            int sg_count;
-            unsigned char msg_len_buf[4];
-        } out;
-        struct incoming_krb5_message in;
-    } x;
+    struct incoming_message in;
+    struct outgoing_message out;
     krb5_data callback_buffer;
     size_t server_index;
     struct conn_state *next;
@@ -461,30 +461,31 @@ static int service_tcp_fd(krb5_context context, struct conn_state *conn,
 static int service_udp_fd(krb5_context context, struct conn_state *conn,
                           struct select_state *selstate, int ssflags);
 
+/* Set up the actual message we will send across the underlying transport to
+ * communicate the payload message, using one or both of state->out.sgbuf. */
 static void
-set_conn_state_msg_length (struct conn_state *state, const krb5_data *message)
+set_transport_message(struct conn_state *state, const krb5_data *message)
 {
-    if (!message || message->length == 0)
+    struct outgoing_message *out = &state->out;
+
+    if (message == NULL || message->length == 0)
         return;
 
     if (state->addr.type == SOCK_STREAM) {
-        store_32_be(message->length, state->x.out.msg_len_buf);
-        SG_SET(&state->x.out.sgbuf[0], state->x.out.msg_len_buf, 4);
-        SG_SET(&state->x.out.sgbuf[1], message->data, message->length);
-        state->x.out.sg_count = 2;
-
+        store_32_be(message->length, out->msg_len_buf);
+        SG_SET(&out->sgbuf[0], out->msg_len_buf, 4);
+        SG_SET(&out->sgbuf[1], message->data, message->length);
+        out->sg_count = 2;
     } else {
-
-        SG_SET(&state->x.out.sgbuf[0], message->data, message->length);
-        SG_SET(&state->x.out.sgbuf[1], 0, 0);
-        state->x.out.sg_count = 1;
-
+        SG_SET(&out->sgbuf[0], message->data, message->length);
+        SG_SET(&out->sgbuf[1], NULL, 0);
+        out->sg_count = 1;
     }
 }
 
 static krb5_error_code
 add_connection(struct conn_state **conns, struct addrinfo *ai,
-               size_t server_index, const krb5_data *message, char **udpbufp)
+               size_t server_index, char **udpbufp)
 {
     struct conn_state *state, **tailptr;
 
@@ -492,28 +493,26 @@ add_connection(struct conn_state **conns, struct addrinfo *ai,
     if (state == NULL)
         return ENOMEM;
     state->state = INITIALIZING;
-    state->x.out.sgp = state->x.out.sgbuf;
+    state->out.sgp = state->out.sgbuf;
     state->addr.type = ai->ai_socktype;
     state->addr.family = ai->ai_family;
     state->addr.len = ai->ai_addrlen;
     memcpy(&state->addr.saddr, ai->ai_addr, ai->ai_addrlen);
     state->fd = INVALID_SOCKET;
     state->server_index = server_index;
-    SG_SET(&state->x.out.sgbuf[1], 0, 0);
+    SG_SET(&state->out.sgbuf[1], NULL, 0);
     if (ai->ai_socktype == SOCK_STREAM) {
         state->service = service_tcp_fd;
-        set_conn_state_msg_length (state, message);
     } else {
         state->service = service_udp_fd;
-        set_conn_state_msg_length (state, message);
 
         if (*udpbufp == NULL) {
             *udpbufp = malloc(MAX_DGRAM_SIZE);
             if (*udpbufp == 0)
                 return ENOMEM;
         }
-        state->x.in.buf = *udpbufp;
-        state->x.in.bufsize = MAX_DGRAM_SIZE;
+        state->in.buf = *udpbufp;
+        state->in.bufsize = MAX_DGRAM_SIZE;
     }
 
     /* Chain the new state onto the tail of the list. */
@@ -597,7 +596,7 @@ resolve_server(krb5_context context, const struct serverlist *servers,
         ai.ai_family = entry->family;
         ai.ai_addrlen = entry->addrlen;
         ai.ai_addr = (struct sockaddr *)&entry->addr;
-        return add_connection(conns, &ai, ind, message, udpbufp);
+        return add_connection(conns, &ai, ind, udpbufp);
     }
 
     memset(&hint, 0, sizeof(hint));
@@ -617,12 +616,12 @@ resolve_server(krb5_context context, const struct serverlist *servers,
     /* Add each address with the preferred socktype. */
     retval = 0;
     for (a = addrs; a != 0 && retval == 0; a = a->ai_next)
-        retval = add_connection(conns, a, ind, message, udpbufp);
+        retval = add_connection(conns, a, ind, udpbufp);
     if (retval == 0 && entry->socktype == 0 && socktype2 != 0) {
         /* Add each address again with the non-preferred socktype. */
         for (a = addrs; a != 0 && retval == 0; a = a->ai_next) {
             a->ai_socktype = socktype2;
-            retval = add_connection(conns, a, ind, message, udpbufp);
+            retval = add_connection(conns, a, ind, udpbufp);
         }
     }
     freeaddrinfo(addrs);
@@ -631,7 +630,7 @@ resolve_server(krb5_context context, const struct serverlist *servers,
 
 static int
 start_connection(krb5_context context, struct conn_state *state,
-                 struct select_state *selstate,
+                 const krb5_data *message, struct select_state *selstate,
                  struct sendto_callback_info *callback_info)
 {
     int fd, e;
@@ -689,13 +688,14 @@ start_connection(krb5_context context, struct conn_state *state,
             return -3;
         }
 
-        set_conn_state_msg_length(state, &state->callback_buffer);
+        message = &state->callback_buffer;
     }
+    set_transport_message(state, message);
 
     if (state->addr.type == SOCK_DGRAM) {
         /* Send it now.  */
         ssize_t ret;
-        sg_buf *sg = &state->x.out.sgbuf[0];
+        sg_buf *sg = &state->out.sgbuf[0];
 
         TRACE_SENDTO_KDC_UDP_SEND_INITIAL(context, &state->addr);
         ret = send(state->fd, SG_BUF(sg), SG_LEN(sg), 0);
@@ -731,14 +731,16 @@ start_connection(krb5_context context, struct conn_state *state,
    next connection.  */
 static int
 maybe_send(krb5_context context, struct conn_state *conn,
-           struct select_state *selstate,
+           const krb5_data *message, struct select_state *selstate,
            struct sendto_callback_info *callback_info)
 {
     sg_buf *sg;
     ssize_t ret;
 
-    if (conn->state == INITIALIZING)
-        return start_connection(context, conn, selstate, callback_info);
+    if (conn->state == INITIALIZING) {
+        return start_connection(context, conn, message, selstate,
+                                callback_info);
+    }
 
     /* Did we already shut down this channel?  */
     if (conn->state == FAILED) {
@@ -752,7 +754,7 @@ maybe_send(krb5_context context, struct conn_state *conn,
     }
 
     /* UDP - retransmit after a previous attempt timed out. */
-    sg = &conn->x.out.sgbuf[0];
+    sg = &conn->out.sgbuf[0];
     TRACE_SENDTO_KDC_UDP_SEND_RETRY(context, &conn->addr);
     ret = send(conn->fd, SG_BUF(sg), SG_LEN(sg), 0);
     if (ret < 0 || (size_t) ret != SG_LEN(sg)) {
@@ -803,6 +805,8 @@ service_tcp_fd(krb5_context context, struct conn_state *conn,
     int e = 0;
     ssize_t nwritten, nread;
     SOCKET_WRITEV_TEMP tmp;
+    struct incoming_message *in = &conn->in;
+    struct outgoing_message *out = &conn->out;
 
     /* Check for a socket exception. */
     if (ssflags & SSF_EXCEPTION)
@@ -825,68 +829,68 @@ service_tcp_fd(krb5_context context, struct conn_state *conn,
         /* Fall through. */
     case WRITING:
         TRACE_SENDTO_KDC_TCP_SEND(context, &conn->addr);
-        nwritten = SOCKET_WRITEV(conn->fd, conn->x.out.sgp,
-                                 conn->x.out.sg_count, tmp);
+        nwritten = SOCKET_WRITEV(conn->fd, out->sgp, out->sg_count, tmp);
         if (nwritten < 0) {
             TRACE_SENDTO_KDC_TCP_ERROR_SEND(context, &conn->addr,
                                             SOCKET_ERRNO);
             goto kill_conn;
         }
         while (nwritten) {
-            sg_buf *sgp = conn->x.out.sgp;
+            sg_buf *sgp = out->sgp;
             if ((size_t) nwritten < SG_LEN(sgp)) {
                 SG_ADVANCE(sgp, (size_t) nwritten);
                 nwritten = 0;
             } else {
                 nwritten -= SG_LEN(sgp);
-                conn->x.out.sgp++;
-                conn->x.out.sg_count--;
+                out->sgp++;
+                out->sg_count--;
             }
         }
-        if (conn->x.out.sg_count == 0) {
+        if (out->sg_count == 0) {
             /* Done writing, switch to reading. */
             cm_read(selstate, conn->fd);
             conn->state = READING;
-            conn->x.in.bufsizebytes_read = 0;
-            conn->x.in.bufsize = 0;
-            conn->x.in.buf = 0;
-            conn->x.in.pos = 0;
-            conn->x.in.n_left = 0;
+            in->bufsizebytes_read = 0;
+            in->bufsize = 0;
+            in->pos = 0;
+            in->buf = NULL;
+            in->n_left = 0;
         }
         return 0;
 
     case READING:
-        if (conn->x.in.bufsizebytes_read == 4) {
+        if (in->bufsizebytes_read == 4) {
             /* Reading data.  */
-            nread = SOCKET_READ(conn->fd, conn->x.in.pos, conn->x.in.n_left);
+            nread = SOCKET_READ(conn->fd, &in->buf[in->pos], in->n_left);
             if (nread <= 0) {
                 e = nread ? SOCKET_ERRNO : ECONNRESET;
                 TRACE_SENDTO_KDC_TCP_ERROR_RECV(context, &conn->addr, e);
                 goto kill_conn;
             }
-            conn->x.in.n_left -= nread;
-            conn->x.in.pos += nread;
-            if (conn->x.in.n_left <= 0)
+            in->n_left -= nread;
+            in->pos += nread;
+            if (in->n_left <= 0)
                 return 1;
         } else {
             /* Reading length.  */
             nread = SOCKET_READ(conn->fd,
-                                conn->x.in.bufsizebytes + conn->x.in.bufsizebytes_read,
-                                4 - conn->x.in.bufsizebytes_read);
+                                in->bufsizebytes + in->bufsizebytes_read,
+                                4 - in->bufsizebytes_read);
             if (nread <= 0) {
                 e = nread ? SOCKET_ERRNO : ECONNRESET;
                 TRACE_SENDTO_KDC_TCP_ERROR_RECV_LEN(context, &conn->addr, e);
                 goto kill_conn;
             }
-            conn->x.in.bufsizebytes_read += nread;
-            if (conn->x.in.bufsizebytes_read == 4) {
-                unsigned long len = load_32_be (conn->x.in.bufsizebytes);
+            in->bufsizebytes_read += nread;
+            if (in->bufsizebytes_read == 4) {
+                unsigned long len = load_32_be(in->bufsizebytes);
                 /* Arbitrary 1M cap.  */
                 if (len > 1 * 1024 * 1024)
                     goto kill_conn;
-                conn->x.in.bufsize = conn->x.in.n_left = len;
-                conn->x.in.buf = conn->x.in.pos = malloc(len);
-                if (conn->x.in.buf == 0)
+                in->bufsize = in->n_left = len;
+                in->pos = 0;
+                in->buf = malloc(len);
+                if (in->buf == NULL)
                     goto kill_conn;
             }
         }
@@ -915,13 +919,13 @@ service_udp_fd(krb5_context context, struct conn_state *conn,
     if (conn->state != READING)
         abort();
 
-    nread = recv(conn->fd, conn->x.in.buf, conn->x.in.bufsize, 0);
+    nread = recv(conn->fd, conn->in.buf, conn->in.bufsize, 0);
     if (nread < 0) {
         TRACE_SENDTO_KDC_UDP_ERROR_RECV(context, &conn->addr, SOCKET_ERRNO);
         kill_conn(conn, selstate);
         return 0;
     }
-    conn->x.in.pos = conn->x.in.buf + nread;
+    conn->in.pos = nread;
     return 1;
 }
 
@@ -986,10 +990,7 @@ service_fds(krb5_context context, struct select_state *selstate,
                 int stop = 1;
 
                 if (msg_handler != NULL) {
-                    krb5_data reply;
-
-                    reply.data = state->x.in.buf;
-                    reply.length = state->x.in.pos - state->x.in.buf;
+                    krb5_data reply = make_data(state->in.buf, state->in.pos);
 
                     stop = (msg_handler(context, &reply, msg_handler_data) != 0);
                 }
@@ -1051,8 +1052,7 @@ k5_sendto(krb5_context context, const krb5_data *message,
     char *udpbuf = NULL;
     krb5_boolean done = FALSE;
 
-    reply->data = 0;
-    reply->length = 0;
+    *reply = empty_data();
 
     /* One for use here, listing all our fds in use, and one for
      * temporary use in service_fds, for the fds of interest.  */
@@ -1077,7 +1077,7 @@ k5_sendto(krb5_context context, const krb5_data *message,
             /* Contact each new connection whose socktype matches socktype1. */
             if (state->addr.type != socktype1)
                 continue;
-            if (maybe_send(context, state, sel_state, callback_info))
+            if (maybe_send(context, state, message, sel_state, callback_info))
                 continue;
             done = service_fds(context, sel_state, 1000, conns, seltemp,
                                msg_handler, msg_handler_data, &winner);
@@ -1089,7 +1089,7 @@ k5_sendto(krb5_context context, const krb5_data *message,
     for (state = conns; state != NULL && !done; state = state->next) {
         if (state->addr.type != socktype2)
             continue;
-        if (maybe_send(context, state, sel_state, callback_info))
+        if (maybe_send(context, state, message, sel_state, callback_info))
             continue;
         done = service_fds(context, sel_state, 1000, conns, seltemp,
                            msg_handler, msg_handler_data, &winner);
@@ -1105,7 +1105,7 @@ k5_sendto(krb5_context context, const krb5_data *message,
     delay = 4000;
     for (pass = 1; pass < MAX_PASS && !done; pass++) {
         for (state = conns; state != NULL && !done; state = state->next) {
-            if (maybe_send(context, state, sel_state, callback_info))
+            if (maybe_send(context, state, message, sel_state, callback_info))
                 continue;
             done = service_fds(context, sel_state, 1000, conns, seltemp,
                                msg_handler, msg_handler_data, &winner);
@@ -1127,10 +1127,9 @@ k5_sendto(krb5_context context, const krb5_data *message,
         goto cleanup;
     }
     /* Success!  */
-    reply->data = winner->x.in.buf;
-    reply->length = winner->x.in.pos - winner->x.in.buf;
+    *reply = make_data(winner->in.buf, winner->in.pos);
     retval = 0;
-    winner->x.in.buf = NULL;
+    winner->in.buf = NULL;
     if (server_used != NULL)
         *server_used = winner->server_index;
     if (remoteaddr != NULL && remoteaddrlen != 0 && *remoteaddrlen > 0)
@@ -1142,8 +1141,8 @@ cleanup:
         next = state->next;
         if (state->fd != INVALID_SOCKET)
             closesocket(state->fd);
-        if (state->state == READING && state->x.in.buf != udpbuf)
-            free(state->x.in.buf);
+        if (state->state == READING && state->in.buf != udpbuf)
+            free(state->in.buf);
         if (callback_info) {
             callback_info->pfn_cleanup(callback_info->data,
                                        &state->callback_buffer);
-- 
2.1.0