Blob Blame History Raw
Tweaked a bit to apply to 1.12:
* krb5_int64 hadn't been replaced by int64_t yet.

commit 346883c48f1b9e09b1af2cf73e3b96ee8f934072
Author: Greg Hudson <ghudson@mit.edu>
Date:   Wed Mar 26 13:21:45 2014 -0400

    Refactor cm functions in sendto_kdc.c
    
    Move get_curtime_ms and the cm functions near the top of the file
    right after structure definitions.  Except for cm_select_or_poll,
    define each cm function separately for poll and for select, since the
    implementations don't share much in common.  Instead of
    cm_unset_write, define cm_read and cm_write functions to put an fd in
    read-only or write-only state.  Remove the ssflags argument from
    cm_add_fd and just expect the caller to make a subsequent call to
    cm_read or cm_write.  Always select for exceptions when using select.
    (Polling for exceptions is implicit with poll).
    
    With these changes, we no longer select/poll for reading on a TCP
    connection until we are done writing to it.  So in service_tcp_fd,
    remove the check for unexpected read events.

diff --git a/src/lib/krb5/os/sendto_kdc.c b/src/lib/krb5/os/sendto_kdc.c
index e60a375..e773a0a 100644
--- a/src/lib/krb5/os/sendto_kdc.c
+++ b/src/lib/krb5/os/sendto_kdc.c
@@ -59,8 +59,7 @@
 
 typedef krb5_int64 time_ms;
 
-/* Since fd_set is large on some platforms (8K on AIX 5.2), this probably
- * shouldn't be allocated in automatic storage. */
+/* This can be pretty large, so should not be stack-allocated. */
 struct select_state {
 #ifdef USE_POLL
     struct pollfd fds[MAX_POLLFDS];
@@ -107,6 +106,183 @@ struct conn_state {
     time_ms endtime;
 };
 
+/* Get current time in milliseconds. */
+static krb5_error_code
+get_curtime_ms(time_ms *time_out)
+{
+    struct timeval tv;
+
+    if (gettimeofday(&tv, 0))
+        return errno;
+    *time_out = (time_ms)tv.tv_sec * 1000 + tv.tv_usec / 1000;
+    return 0;
+}
+
+#ifdef USE_POLL
+
+/* Find a pollfd in selstate by fd, or abort if we can't find it. */
+static inline struct pollfd *
+find_pollfd(struct select_state *selstate, int fd)
+{
+    int i;
+
+    for (i = 0; i < selstate->nfds; i++) {
+        if (selstate->fds[i].fd == fd)
+            return &selstate->fds[i];
+    }
+    abort();
+}
+
+static void
+cm_init_selstate(struct select_state *selstate)
+{
+    selstate->nfds = 0;
+}
+
+static krb5_boolean
+cm_add_fd(struct select_state *selstate, int fd)
+{
+    if (selstate->nfds >= MAX_POLLFDS)
+        return FALSE;
+    selstate->fds[selstate->nfds].fd = fd;
+    selstate->fds[selstate->nfds].events = 0;
+    selstate->nfds++;
+    return TRUE;
+}
+
+static void
+cm_remove_fd(struct select_state *selstate, int fd)
+{
+    struct pollfd *pfd = find_pollfd(selstate, fd);
+
+    *pfd = selstate->fds[selstate->nfds - 1];
+    selstate->nfds--;
+}
+
+/* Poll for reading (and not writing) on fd the next time we poll. */
+static void
+cm_read(struct select_state *selstate, int fd)
+{
+    find_pollfd(selstate, fd)->events = POLLIN;
+}
+
+/* Poll for writing (and not reading) on fd the next time we poll. */
+static void
+cm_write(struct select_state *selstate, int fd)
+{
+    find_pollfd(selstate, fd)->events = POLLOUT;
+}
+
+/* Get the output events for fd in the form of ssflags. */
+static unsigned int
+cm_get_ssflags(struct select_state *selstate, int fd)
+{
+    struct pollfd *pfd = find_pollfd(selstate, fd);
+
+    return ((pfd->revents & POLLIN) ? SSF_READ : 0) |
+        ((pfd->revents & POLLOUT) ? SSF_WRITE : 0) |
+        ((pfd->revents & POLLERR) ? SSF_EXCEPTION : 0);
+}
+
+#else /* not USE_POLL */
+
+static void
+cm_init_selstate(struct select_state *selstate)
+{
+    selstate->nfds = 0;
+    selstate->max = 0;
+    FD_ZERO(&selstate->rfds);
+    FD_ZERO(&selstate->wfds);
+    FD_ZERO(&selstate->xfds);
+}
+
+static krb5_boolean
+cm_add_fd(struct select_state *selstate, int fd)
+{
+#ifndef _WIN32  /* On Windows FD_SETSIZE is a count, not a max value. */
+    if (fd >= FD_SETSIZE)
+        return FALSE;
+#endif
+    FD_SET(fd, &selstate->xfds);
+    if (selstate->max <= fd)
+        selstate->max = fd + 1;
+    selstate->nfds++;
+    return TRUE;
+}
+
+static void
+cm_remove_fd(struct select_state *selstate, int fd)
+{
+    FD_CLR(fd, &selstate->rfds);
+    FD_CLR(fd, &selstate->wfds);
+    FD_CLR(fd, &selstate->xfds);
+    if (selstate->max == fd + 1) {
+        while (selstate->max > 0 &&
+               !FD_ISSET(selstate->max - 1, &selstate->rfds) &&
+               !FD_ISSET(selstate->max - 1, &selstate->wfds) &&
+               !FD_ISSET(selstate->max - 1, &selstate->xfds))
+            selstate->max--;
+    }
+    selstate->nfds--;
+}
+
+/* Select for reading (and not writing) on fd the next time we select. */
+static void
+cm_read(struct select_state *selstate, int fd)
+{
+    FD_SET(fd, &selstate->rfds);
+    FD_CLR(fd, &selstate->wfds);
+}
+
+/* Select for writing (and not reading) on fd the next time we select. */
+static void
+cm_write(struct select_state *selstate, int fd)
+{
+    FD_CLR(fd, &selstate->rfds);
+    FD_SET(fd, &selstate->wfds);
+}
+
+/* Get the events for fd from selstate after a select. */
+static unsigned int
+cm_get_ssflags(struct select_state *selstate, int fd)
+{
+    return (FD_ISSET(fd, &selstate->rfds) ? SSF_READ : 0) |
+        (FD_ISSET(fd, &selstate->wfds) ? SSF_WRITE : 0) |
+        (FD_ISSET(fd, &selstate->xfds) ? SSF_EXCEPTION : 0);
+}
+
+#endif /* not USE_POLL */
+
+static krb5_error_code
+cm_select_or_poll(const struct select_state *in, time_ms endtime,
+                  struct select_state *out, int *sret)
+{
+#ifndef USE_POLL
+    struct timeval tv;
+#endif
+    krb5_error_code retval;
+    time_ms curtime, interval;
+
+    retval = get_curtime_ms(&curtime);
+    if (retval != 0)
+        return retval;
+    interval = (curtime < endtime) ? endtime - curtime : 0;
+
+    /* We don't need a separate copy of the selstate for poll, but use one for
+     * consistency with how we use select. */
+    *out = *in;
+
+#ifdef USE_POLL
+    *sret = poll(out->fds, out->nfds, interval);
+#else
+    tv.tv_sec = interval / 1000;
+    tv.tv_usec = interval % 1000 * 1000;
+    *sret = select(out->max, &out->rfds, &out->wfds, &out->xfds, &tv);
+#endif
+
+    return (*sret < 0) ? SOCKET_ERRNO : 0;
+}
+
 static int
 in_addrlist(struct server_entry *entry, struct serverlist *list)
 {
@@ -251,18 +427,6 @@ cleanup:
     return retval;
 }
 
-/* Get current time in milliseconds. */
-static krb5_error_code
-get_curtime_ms(time_ms *time_out)
-{
-    struct timeval tv;
-
-    if (gettimeofday(&tv, 0))
-        return errno;
-    *time_out = (time_ms)tv.tv_sec * 1000 + tv.tv_usec / 1000;
-    return 0;
-}
-
 /*
  * Notes:
  *
@@ -283,144 +447,6 @@ get_curtime_ms(time_ms *time_out)
  *   connections already in progress
  */
 
-static void
-cm_init_selstate(struct select_state *selstate)
-{
-    selstate->nfds = 0;
-#ifndef USE_POLL
-    selstate->max = 0;
-    FD_ZERO(&selstate->rfds);
-    FD_ZERO(&selstate->wfds);
-    FD_ZERO(&selstate->xfds);
-#endif
-}
-
-static krb5_boolean
-cm_add_fd(struct select_state *selstate, int fd, unsigned int ssflags)
-{
-#ifdef USE_POLL
-    if (selstate->nfds >= MAX_POLLFDS)
-        return FALSE;
-    selstate->fds[selstate->nfds].fd = fd;
-    selstate->fds[selstate->nfds].events = 0;
-    if (ssflags & SSF_READ)
-        selstate->fds[selstate->nfds].events |= POLLIN;
-    if (ssflags & SSF_WRITE)
-        selstate->fds[selstate->nfds].events |= POLLOUT;
-#else
-#ifndef _WIN32  /* On Windows FD_SETSIZE is a count, not a max value. */
-    if (fd >= FD_SETSIZE)
-        return FALSE;
-#endif
-    if (ssflags & SSF_READ)
-        FD_SET(fd, &selstate->rfds);
-    if (ssflags & SSF_WRITE)
-        FD_SET(fd, &selstate->wfds);
-    if (ssflags & SSF_EXCEPTION)
-        FD_SET(fd, &selstate->xfds);
-    if (selstate->max <= fd)
-        selstate->max = fd + 1;
-#endif
-    selstate->nfds++;
-    return TRUE;
-}
-
-static void
-cm_remove_fd(struct select_state *selstate, int fd)
-{
-#ifdef USE_POLL
-    int i;
-
-    /* Find the FD in the array and move the last entry to its place. */
-    assert(selstate->nfds > 0);
-    for (i = 0; i < selstate->nfds && selstate->fds[i].fd != fd; i++);
-    assert(i < selstate->nfds);
-    selstate->fds[i] = selstate->fds[selstate->nfds - 1];
-#else
-    FD_CLR(fd, &selstate->rfds);
-    FD_CLR(fd, &selstate->wfds);
-    FD_CLR(fd, &selstate->xfds);
-    if (selstate->max == 1 + fd) {
-        while (selstate->max > 0
-               && ! FD_ISSET(selstate->max-1, &selstate->rfds)
-               && ! FD_ISSET(selstate->max-1, &selstate->wfds)
-               && ! FD_ISSET(selstate->max-1, &selstate->xfds))
-            selstate->max--;
-    }
-#endif
-    selstate->nfds--;
-}
-
-static void
-cm_unset_write(struct select_state *selstate, int fd)
-{
-#ifdef USE_POLL
-    int i;
-
-    for (i = 0; i < selstate->nfds && selstate->fds[i].fd != fd; i++);
-    assert(i < selstate->nfds);
-    selstate->fds[i].events &= ~POLLOUT;
-#else
-    FD_CLR(fd, &selstate->wfds);
-#endif
-}
-
-static krb5_error_code
-cm_select_or_poll(const struct select_state *in, time_ms endtime,
-                  struct select_state *out, int *sret)
-{
-#ifndef USE_POLL
-    struct timeval tv;
-#endif
-    krb5_error_code retval;
-    time_ms curtime, interval;
-
-    retval = get_curtime_ms(&curtime);
-    if (retval != 0)
-        return retval;
-    interval = (curtime < endtime) ? endtime - curtime : 0;
-
-    /* We don't need a separate copy of the selstate for poll, but use one for
-     * consistency with how we use select. */
-    *out = *in;
-
-#ifdef USE_POLL
-    *sret = poll(out->fds, out->nfds, interval);
-#else
-    tv.tv_sec = interval / 1000;
-    tv.tv_usec = interval % 1000 * 1000;
-    *sret = select(out->max, &out->rfds, &out->wfds, &out->xfds, &tv);
-#endif
-
-    return (*sret < 0) ? SOCKET_ERRNO : 0;
-}
-
-static unsigned int
-cm_get_ssflags(struct select_state *selstate, int fd)
-{
-    unsigned int ssflags = 0;
-#ifdef USE_POLL
-    int i;
-
-    for (i = 0; i < selstate->nfds && selstate->fds[i].fd != fd; i++);
-    assert(i < selstate->nfds);
-    if (selstate->fds[i].revents & POLLIN)
-        ssflags |= SSF_READ;
-    if (selstate->fds[i].revents & POLLOUT)
-        ssflags |= SSF_WRITE;
-    if (selstate->fds[i].revents & POLLERR)
-        ssflags |= SSF_EXCEPTION;
-#else
-    if (FD_ISSET(fd, &selstate->rfds))
-        ssflags |= SSF_READ;
-    if (FD_ISSET(fd, &selstate->wfds))
-        ssflags |= SSF_WRITE;
-    if (FD_ISSET(fd, &selstate->xfds))
-        ssflags |= SSF_EXCEPTION;
-#endif
-    return ssflags;
-}
-
 static int service_tcp_fd(krb5_context context, struct conn_state *conn,
                           struct select_state *selstate, int ssflags);
 static int service_udp_fd(krb5_context context, struct conn_state *conn,
@@ -600,7 +626,6 @@ start_connection(krb5_context context, struct conn_state *state,
                  struct sendto_callback_info *callback_info)
 {
     int fd, e;
-    unsigned int ssflags;
     static const int one = 1;
     static const struct linger lopt = { 0, 0 };
 
@@ -676,15 +701,17 @@ start_connection(krb5_context context, struct conn_state *state,
             state->state = READING;
         }
     }
-    ssflags = SSF_READ | SSF_EXCEPTION;
-    if (state->state == CONNECTING || state->state == WRITING)
-        ssflags |= SSF_WRITE;
-    if (!cm_add_fd(selstate, state->fd, ssflags)) {
+
+    if (!cm_add_fd(selstate, state->fd)) {
         (void) closesocket(state->fd);
         state->fd = INVALID_SOCKET;
         state->state = FAILED;
         return -1;
     }
+    if (state->state == CONNECTING || state->state == WRITING)
+        cm_write(selstate, state->fd);
+    else
+        cm_read(selstate, state->fd);
 
     return 0;
 }
@@ -768,9 +795,8 @@ service_tcp_fd(krb5_context context, struct conn_state *conn,
     ssize_t nwritten, nread;
     SOCKET_WRITEV_TEMP tmp;
 
-    /* Check for a socket exception or readable data before we expect it. */
-    if (ssflags & SSF_EXCEPTION ||
-        ((ssflags & SSF_READ) && conn->state != READING))
+    /* Check for a socket exception. */
+    if (ssflags & SSF_EXCEPTION)
         goto kill_conn;
 
     switch (conn->state) {
@@ -810,7 +836,7 @@ service_tcp_fd(krb5_context context, struct conn_state *conn,
         }
         if (conn->x.out.sg_count == 0) {
             /* Done writing, switch to reading. */
-            cm_unset_write(selstate, conn->fd);
+            cm_read(selstate, conn->fd);
             conn->state = READING;
             conn->x.in.bufsizebytes_read = 0;
             conn->x.in.bufsize = 0;