Tweaked a bit to apply to 1.12:
* krb5_int64 hadn't been replaced by int64_t yet.
commit 346883c48f1b9e09b1af2cf73e3b96ee8f934072
Author: Greg Hudson <ghudson@mit.edu>
Date: Wed Mar 26 13:21:45 2014 -0400
Refactor cm functions in sendto_kdc.c
Move get_curtime_ms and the cm functions near the top of the file
right after structure definitions. Except for cm_select_or_poll,
define each cm function separately for poll and for select, since the
implementations don't share much in common. Instead of
cm_unset_write, define cm_read and cm_write functions to put an fd in
read-only or write-only state. Remove the ssflags argument from
cm_add_fd and just expect the caller to make a subsequent call to
cm_read or cm_write. Always select for exceptions when using select.
(Polling for exceptions is implicit with poll).
With these changes, we no longer select/poll for reading on a TCP
connection until we are done writing to it. So in service_tcp_fd,
remove the check for unexpected read events.
diff --git a/src/lib/krb5/os/sendto_kdc.c b/src/lib/krb5/os/sendto_kdc.c
index e60a375..e773a0a 100644
--- a/src/lib/krb5/os/sendto_kdc.c
+++ b/src/lib/krb5/os/sendto_kdc.c
@@ -59,8 +59,7 @@
typedef krb5_int64 time_ms;
-/* Since fd_set is large on some platforms (8K on AIX 5.2), this probably
- * shouldn't be allocated in automatic storage. */
+/* This can be pretty large, so should not be stack-allocated. */
struct select_state {
#ifdef USE_POLL
struct pollfd fds[MAX_POLLFDS];
@@ -107,6 +106,183 @@ struct conn_state {
time_ms endtime;
};
+/* Get current time in milliseconds. */
+static krb5_error_code
+get_curtime_ms(time_ms *time_out)
+{
+ struct timeval tv;
+
+ if (gettimeofday(&tv, 0))
+ return errno;
+ *time_out = (time_ms)tv.tv_sec * 1000 + tv.tv_usec / 1000;
+ return 0;
+}
+
+#ifdef USE_POLL
+
+/* Find a pollfd in selstate by fd, or abort if we can't find it. */
+static inline struct pollfd *
+find_pollfd(struct select_state *selstate, int fd)
+{
+ int i;
+
+ for (i = 0; i < selstate->nfds; i++) {
+ if (selstate->fds[i].fd == fd)
+ return &selstate->fds[i];
+ }
+ abort();
+}
+
+static void
+cm_init_selstate(struct select_state *selstate)
+{
+ selstate->nfds = 0;
+}
+
+static krb5_boolean
+cm_add_fd(struct select_state *selstate, int fd)
+{
+ if (selstate->nfds >= MAX_POLLFDS)
+ return FALSE;
+ selstate->fds[selstate->nfds].fd = fd;
+ selstate->fds[selstate->nfds].events = 0;
+ selstate->nfds++;
+ return TRUE;
+}
+
+static void
+cm_remove_fd(struct select_state *selstate, int fd)
+{
+ struct pollfd *pfd = find_pollfd(selstate, fd);
+
+ *pfd = selstate->fds[selstate->nfds - 1];
+ selstate->nfds--;
+}
+
+/* Poll for reading (and not writing) on fd the next time we poll. */
+static void
+cm_read(struct select_state *selstate, int fd)
+{
+ find_pollfd(selstate, fd)->events = POLLIN;
+}
+
+/* Poll for writing (and not reading) on fd the next time we poll. */
+static void
+cm_write(struct select_state *selstate, int fd)
+{
+ find_pollfd(selstate, fd)->events = POLLOUT;
+}
+
+/* Get the output events for fd in the form of ssflags. */
+static unsigned int
+cm_get_ssflags(struct select_state *selstate, int fd)
+{
+ struct pollfd *pfd = find_pollfd(selstate, fd);
+
+ return ((pfd->revents & POLLIN) ? SSF_READ : 0) |
+ ((pfd->revents & POLLOUT) ? SSF_WRITE : 0) |
+ ((pfd->revents & POLLERR) ? SSF_EXCEPTION : 0);
+}
+
+#else /* not USE_POLL */
+
+static void
+cm_init_selstate(struct select_state *selstate)
+{
+ selstate->nfds = 0;
+ selstate->max = 0;
+ FD_ZERO(&selstate->rfds);
+ FD_ZERO(&selstate->wfds);
+ FD_ZERO(&selstate->xfds);
+}
+
+static krb5_boolean
+cm_add_fd(struct select_state *selstate, int fd)
+{
+#ifndef _WIN32 /* On Windows FD_SETSIZE is a count, not a max value. */
+ if (fd >= FD_SETSIZE)
+ return FALSE;
+#endif
+ FD_SET(fd, &selstate->xfds);
+ if (selstate->max <= fd)
+ selstate->max = fd + 1;
+ selstate->nfds++;
+ return TRUE;
+}
+
+static void
+cm_remove_fd(struct select_state *selstate, int fd)
+{
+ FD_CLR(fd, &selstate->rfds);
+ FD_CLR(fd, &selstate->wfds);
+ FD_CLR(fd, &selstate->xfds);
+ if (selstate->max == fd + 1) {
+ while (selstate->max > 0 &&
+ !FD_ISSET(selstate->max - 1, &selstate->rfds) &&
+ !FD_ISSET(selstate->max - 1, &selstate->wfds) &&
+ !FD_ISSET(selstate->max - 1, &selstate->xfds))
+ selstate->max--;
+ }
+ selstate->nfds--;
+}
+
+/* Select for reading (and not writing) on fd the next time we select. */
+static void
+cm_read(struct select_state *selstate, int fd)
+{
+ FD_SET(fd, &selstate->rfds);
+ FD_CLR(fd, &selstate->wfds);
+}
+
+/* Select for writing (and not reading) on fd the next time we select. */
+static void
+cm_write(struct select_state *selstate, int fd)
+{
+ FD_CLR(fd, &selstate->rfds);
+ FD_SET(fd, &selstate->wfds);
+}
+
+/* Get the events for fd from selstate after a select. */
+static unsigned int
+cm_get_ssflags(struct select_state *selstate, int fd)
+{
+ return (FD_ISSET(fd, &selstate->rfds) ? SSF_READ : 0) |
+ (FD_ISSET(fd, &selstate->wfds) ? SSF_WRITE : 0) |
+ (FD_ISSET(fd, &selstate->xfds) ? SSF_EXCEPTION : 0);
+}
+
+#endif /* not USE_POLL */
+
+static krb5_error_code
+cm_select_or_poll(const struct select_state *in, time_ms endtime,
+ struct select_state *out, int *sret)
+{
+#ifndef USE_POLL
+ struct timeval tv;
+#endif
+ krb5_error_code retval;
+ time_ms curtime, interval;
+
+ retval = get_curtime_ms(&curtime);
+ if (retval != 0)
+ return retval;
+ interval = (curtime < endtime) ? endtime - curtime : 0;
+
+ /* We don't need a separate copy of the selstate for poll, but use one for
+ * consistency with how we use select. */
+ *out = *in;
+
+#ifdef USE_POLL
+ *sret = poll(out->fds, out->nfds, interval);
+#else
+ tv.tv_sec = interval / 1000;
+ tv.tv_usec = interval % 1000 * 1000;
+ *sret = select(out->max, &out->rfds, &out->wfds, &out->xfds, &tv);
+#endif
+
+ return (*sret < 0) ? SOCKET_ERRNO : 0;
+}
+
static int
in_addrlist(struct server_entry *entry, struct serverlist *list)
{
@@ -251,18 +427,6 @@ cleanup:
return retval;
}
-/* Get current time in milliseconds. */
-static krb5_error_code
-get_curtime_ms(time_ms *time_out)
-{
- struct timeval tv;
-
- if (gettimeofday(&tv, 0))
- return errno;
- *time_out = (time_ms)tv.tv_sec * 1000 + tv.tv_usec / 1000;
- return 0;
-}
-
/*
* Notes:
*
@@ -283,144 +447,6 @@ get_curtime_ms(time_ms *time_out)
* connections already in progress
*/
-static void
-cm_init_selstate(struct select_state *selstate)
-{
- selstate->nfds = 0;
-#ifndef USE_POLL
- selstate->max = 0;
- FD_ZERO(&selstate->rfds);
- FD_ZERO(&selstate->wfds);
- FD_ZERO(&selstate->xfds);
-#endif
-}
-
-static krb5_boolean
-cm_add_fd(struct select_state *selstate, int fd, unsigned int ssflags)
-{
-#ifdef USE_POLL
- if (selstate->nfds >= MAX_POLLFDS)
- return FALSE;
- selstate->fds[selstate->nfds].fd = fd;
- selstate->fds[selstate->nfds].events = 0;
- if (ssflags & SSF_READ)
- selstate->fds[selstate->nfds].events |= POLLIN;
- if (ssflags & SSF_WRITE)
- selstate->fds[selstate->nfds].events |= POLLOUT;
-#else
-#ifndef _WIN32 /* On Windows FD_SETSIZE is a count, not a max value. */
- if (fd >= FD_SETSIZE)
- return FALSE;
-#endif
- if (ssflags & SSF_READ)
- FD_SET(fd, &selstate->rfds);
- if (ssflags & SSF_WRITE)
- FD_SET(fd, &selstate->wfds);
- if (ssflags & SSF_EXCEPTION)
- FD_SET(fd, &selstate->xfds);
- if (selstate->max <= fd)
- selstate->max = fd + 1;
-#endif
- selstate->nfds++;
- return TRUE;
-}
-
-static void
-cm_remove_fd(struct select_state *selstate, int fd)
-{
-#ifdef USE_POLL
- int i;
-
- /* Find the FD in the array and move the last entry to its place. */
- assert(selstate->nfds > 0);
- for (i = 0; i < selstate->nfds && selstate->fds[i].fd != fd; i++);
- assert(i < selstate->nfds);
- selstate->fds[i] = selstate->fds[selstate->nfds - 1];
-#else
- FD_CLR(fd, &selstate->rfds);
- FD_CLR(fd, &selstate->wfds);
- FD_CLR(fd, &selstate->xfds);
- if (selstate->max == 1 + fd) {
- while (selstate->max > 0
- && ! FD_ISSET(selstate->max-1, &selstate->rfds)
- && ! FD_ISSET(selstate->max-1, &selstate->wfds)
- && ! FD_ISSET(selstate->max-1, &selstate->xfds))
- selstate->max--;
- }
-#endif
- selstate->nfds--;
-}
-
-static void
-cm_unset_write(struct select_state *selstate, int fd)
-{
-#ifdef USE_POLL
- int i;
-
- for (i = 0; i < selstate->nfds && selstate->fds[i].fd != fd; i++);
- assert(i < selstate->nfds);
- selstate->fds[i].events &= ~POLLOUT;
-#else
- FD_CLR(fd, &selstate->wfds);
-#endif
-}
-
-static krb5_error_code
-cm_select_or_poll(const struct select_state *in, time_ms endtime,
- struct select_state *out, int *sret)
-{
-#ifndef USE_POLL
- struct timeval tv;
-#endif
- krb5_error_code retval;
- time_ms curtime, interval;
-
- retval = get_curtime_ms(&curtime);
- if (retval != 0)
- return retval;
- interval = (curtime < endtime) ? endtime - curtime : 0;
-
- /* We don't need a separate copy of the selstate for poll, but use one for
- * consistency with how we use select. */
- *out = *in;
-
-#ifdef USE_POLL
- *sret = poll(out->fds, out->nfds, interval);
-#else
- tv.tv_sec = interval / 1000;
- tv.tv_usec = interval % 1000 * 1000;
- *sret = select(out->max, &out->rfds, &out->wfds, &out->xfds, &tv);
-#endif
-
- return (*sret < 0) ? SOCKET_ERRNO : 0;
-}
-
-static unsigned int
-cm_get_ssflags(struct select_state *selstate, int fd)
-{
- unsigned int ssflags = 0;
-#ifdef USE_POLL
- int i;
-
- for (i = 0; i < selstate->nfds && selstate->fds[i].fd != fd; i++);
- assert(i < selstate->nfds);
- if (selstate->fds[i].revents & POLLIN)
- ssflags |= SSF_READ;
- if (selstate->fds[i].revents & POLLOUT)
- ssflags |= SSF_WRITE;
- if (selstate->fds[i].revents & POLLERR)
- ssflags |= SSF_EXCEPTION;
-#else
- if (FD_ISSET(fd, &selstate->rfds))
- ssflags |= SSF_READ;
- if (FD_ISSET(fd, &selstate->wfds))
- ssflags |= SSF_WRITE;
- if (FD_ISSET(fd, &selstate->xfds))
- ssflags |= SSF_EXCEPTION;
-#endif
- return ssflags;
-}
-
static int service_tcp_fd(krb5_context context, struct conn_state *conn,
struct select_state *selstate, int ssflags);
static int service_udp_fd(krb5_context context, struct conn_state *conn,
@@ -600,7 +626,6 @@ start_connection(krb5_context context, struct conn_state *state,
struct sendto_callback_info *callback_info)
{
int fd, e;
- unsigned int ssflags;
static const int one = 1;
static const struct linger lopt = { 0, 0 };
@@ -676,15 +701,17 @@ start_connection(krb5_context context, struct conn_state *state,
state->state = READING;
}
}
- ssflags = SSF_READ | SSF_EXCEPTION;
- if (state->state == CONNECTING || state->state == WRITING)
- ssflags |= SSF_WRITE;
- if (!cm_add_fd(selstate, state->fd, ssflags)) {
+
+ if (!cm_add_fd(selstate, state->fd)) {
(void) closesocket(state->fd);
state->fd = INVALID_SOCKET;
state->state = FAILED;
return -1;
}
+ if (state->state == CONNECTING || state->state == WRITING)
+ cm_write(selstate, state->fd);
+ else
+ cm_read(selstate, state->fd);
return 0;
}
@@ -768,9 +795,8 @@ service_tcp_fd(krb5_context context, struct conn_state *conn,
ssize_t nwritten, nread;
SOCKET_WRITEV_TEMP tmp;
- /* Check for a socket exception or readable data before we expect it. */
- if (ssflags & SSF_EXCEPTION ||
- ((ssflags & SSF_READ) && conn->state != READING))
+ /* Check for a socket exception. */
+ if (ssflags & SSF_EXCEPTION)
goto kill_conn;
switch (conn->state) {
@@ -810,7 +836,7 @@ service_tcp_fd(krb5_context context, struct conn_state *conn,
}
if (conn->x.out.sg_count == 0) {
/* Done writing, switch to reading. */
- cm_unset_write(selstate, conn->fd);
+ cm_read(selstate, conn->fd);
conn->state = READING;
conn->x.in.bufsizebytes_read = 0;
conn->x.in.bufsize = 0;