Blob Blame History Raw
From bc7c550c444210aa8decf98ac0c1dcd051fcc532 Mon Sep 17 00:00:00 2001
From: Filipe Brandenburger <filbranden@google.com>
Date: Tue, 24 Jul 2018 18:46:01 -0700
Subject: [PATCH] socket-util: Introduce send_one_fd_iov() and
 receive_one_fd_iov()

These take a struct iovec to send data together with the passed FD.

The receive function returns the FD through an output argument. In case data is
received, but no FD is passed, the receive function will set the output
argument to -1 explicitly.

Update code in dynamic-user to use the new helpers.

(cherry picked from commit d34673ecb825aa9ecf6958b0caab792f5061c56a)

Resolves: #1683319
---
 src/basic/socket-util.c | 97 ++++++++++++++++++++++++++++++++---------
 src/basic/socket-util.h | 10 ++++-
 src/core/dynamic-user.c | 57 ++----------------------
 3 files changed, 90 insertions(+), 74 deletions(-)

diff --git a/src/basic/socket-util.c b/src/basic/socket-util.c
index 3f90a81d35..986bc6e67f 100644
--- a/src/basic/socket-util.c
+++ b/src/basic/socket-util.c
@@ -1011,9 +1011,10 @@ int getpeergroups(int fd, gid_t **ret) {
         return (int) n;
 }
 
-int send_one_fd_sa(
+ssize_t send_one_fd_iov_sa(
                 int transport_fd,
                 int fd,
+                struct iovec *iov, size_t iovlen,
                 const struct sockaddr *sa, socklen_t len,
                 int flags) {
 
@@ -1024,28 +1025,58 @@ int send_one_fd_sa(
         struct msghdr mh = {
                 .msg_name = (struct sockaddr*) sa,
                 .msg_namelen = len,
-                .msg_control = &control,
-                .msg_controllen = sizeof(control),
+                .msg_iov = iov,
+                .msg_iovlen = iovlen,
         };
-        struct cmsghdr *cmsg;
+        ssize_t k;
 
         assert(transport_fd >= 0);
-        assert(fd >= 0);
 
-        cmsg = CMSG_FIRSTHDR(&mh);
-        cmsg->cmsg_level = SOL_SOCKET;
-        cmsg->cmsg_type = SCM_RIGHTS;
-        cmsg->cmsg_len = CMSG_LEN(sizeof(int));
-        memcpy(CMSG_DATA(cmsg), &fd, sizeof(int));
+        /*
+         * We need either an FD or data to send.
+         * If there's nothing, return an error.
+         */
+        if (fd < 0 && !iov)
+                return -EINVAL;
 
-        mh.msg_controllen = CMSG_SPACE(sizeof(int));
-        if (sendmsg(transport_fd, &mh, MSG_NOSIGNAL | flags) < 0)
-                return -errno;
+        if (fd >= 0) {
+                struct cmsghdr *cmsg;
 
-        return 0;
+                mh.msg_control = &control;
+                mh.msg_controllen = sizeof(control);
+
+                cmsg = CMSG_FIRSTHDR(&mh);
+                cmsg->cmsg_level = SOL_SOCKET;
+                cmsg->cmsg_type = SCM_RIGHTS;
+                cmsg->cmsg_len = CMSG_LEN(sizeof(int));
+                memcpy(CMSG_DATA(cmsg), &fd, sizeof(int));
+
+                mh.msg_controllen = CMSG_SPACE(sizeof(int));
+        }
+        k = sendmsg(transport_fd, &mh, MSG_NOSIGNAL | flags);
+        if (k < 0)
+                return (ssize_t) -errno;
+
+        return k;
 }
 
-int receive_one_fd(int transport_fd, int flags) {
+int send_one_fd_sa(
+                int transport_fd,
+                int fd,
+                const struct sockaddr *sa, socklen_t len,
+                int flags) {
+
+        assert(fd >= 0);
+
+        return (int) send_one_fd_iov_sa(transport_fd, fd, NULL, 0, sa, len, flags);
+}
+
+ssize_t receive_one_fd_iov(
+                int transport_fd,
+                struct iovec *iov, size_t iovlen,
+                int flags,
+                int *ret_fd) {
+
         union {
                 struct cmsghdr cmsghdr;
                 uint8_t buf[CMSG_SPACE(sizeof(int))];
@@ -1053,10 +1084,14 @@ int receive_one_fd(int transport_fd, int flags) {
         struct msghdr mh = {
                 .msg_control = &control,
                 .msg_controllen = sizeof(control),
+                .msg_iov = iov,
+                .msg_iovlen = iovlen,
         };
         struct cmsghdr *cmsg, *found = NULL;
+        ssize_t k;
 
         assert(transport_fd >= 0);
+        assert(ret_fd);
 
         /*
          * Receive a single FD via @transport_fd. We don't care for
@@ -1066,8 +1101,9 @@ int receive_one_fd(int transport_fd, int flags) {
          * combination with send_one_fd().
          */
 
-        if (recvmsg(transport_fd, &mh, MSG_CMSG_CLOEXEC | flags) < 0)
-                return -errno;
+        k = recvmsg(transport_fd, &mh, MSG_CMSG_CLOEXEC | flags);
+        if (k < 0)
+                return (ssize_t) -errno;
 
         CMSG_FOREACH(cmsg, &mh) {
                 if (cmsg->cmsg_level == SOL_SOCKET &&
@@ -1079,12 +1115,33 @@ int receive_one_fd(int transport_fd, int flags) {
                 }
         }
 
-        if (!found) {
+        if (!found)
                 cmsg_close_all(&mh);
+
+        /* If didn't receive an FD or any data, return an error. */
+        if (k == 0 && !found)
                 return -EIO;
-        }
 
-        return *(int*) CMSG_DATA(found);
+        if (found)
+                *ret_fd = *(int*) CMSG_DATA(found);
+        else
+                *ret_fd = -1;
+
+        return k;
+}
+
+int receive_one_fd(int transport_fd, int flags) {
+        int fd;
+        ssize_t k;
+
+        k = receive_one_fd_iov(transport_fd, NULL, 0, flags, &fd);
+        if (k == 0)
+                return fd;
+
+        /* k must be negative, since receive_one_fd_iov() only returns
+         * a positive value if data was received through the iov. */
+        assert(k < 0);
+        return (int) k;
 }
 
 ssize_t next_datagram_size_fd(int fd) {
diff --git a/src/basic/socket-util.h b/src/basic/socket-util.h
index 8e23cf2dbd..82781a0de1 100644
--- a/src/basic/socket-util.h
+++ b/src/basic/socket-util.h
@@ -130,11 +130,19 @@ int getpeercred(int fd, struct ucred *ucred);
 int getpeersec(int fd, char **ret);
 int getpeergroups(int fd, gid_t **ret);
 
+ssize_t send_one_fd_iov_sa(
+                int transport_fd,
+                int fd,
+                struct iovec *iov, size_t iovlen,
+                const struct sockaddr *sa, socklen_t len,
+                int flags);
 int send_one_fd_sa(int transport_fd,
                    int fd,
                    const struct sockaddr *sa, socklen_t len,
                    int flags);
-#define send_one_fd(transport_fd, fd, flags) send_one_fd_sa(transport_fd, fd, NULL, 0, flags)
+#define send_one_fd_iov(transport_fd, fd, iov, iovlen, flags) send_one_fd_iov_sa(transport_fd, fd, iov, iovlen, NULL, 0, flags)
+#define send_one_fd(transport_fd, fd, flags) send_one_fd_iov_sa(transport_fd, fd, NULL, 0, NULL, 0, flags)
+ssize_t receive_one_fd_iov(int transport_fd, struct iovec *iov, size_t iovlen, int flags, int *ret_fd);
 int receive_one_fd(int transport_fd, int flags);
 
 ssize_t next_datagram_size_fd(int fd);
diff --git a/src/core/dynamic-user.c b/src/core/dynamic-user.c
index 7c5111ddf6..021fd93a76 100644
--- a/src/core/dynamic-user.c
+++ b/src/core/dynamic-user.c
@@ -312,20 +312,8 @@ static int pick_uid(char **suggested_paths, const char *name, uid_t *ret_uid) {
 static int dynamic_user_pop(DynamicUser *d, uid_t *ret_uid, int *ret_lock_fd) {
         uid_t uid = UID_INVALID;
         struct iovec iov = IOVEC_INIT(&uid, sizeof(uid));
-        union {
-                struct cmsghdr cmsghdr;
-                uint8_t buf[CMSG_SPACE(sizeof(int))];
-        } control = {};
-        struct msghdr mh = {
-                .msg_control = &control,
-                .msg_controllen = sizeof(control),
-                .msg_iov = &iov,
-                .msg_iovlen = 1,
-        };
-        struct cmsghdr *cmsg;
-
+        int lock_fd;
         ssize_t k;
-        int lock_fd = -1;
 
         assert(d);
         assert(ret_uid);
@@ -334,15 +322,9 @@ static int dynamic_user_pop(DynamicUser *d, uid_t *ret_uid, int *ret_lock_fd) {
         /* Read the UID and lock fd that is stored in the storage AF_UNIX socket. This should be called with the lock
          * on the socket taken. */
 
-        k = recvmsg(d->storage_socket[0], &mh, MSG_DONTWAIT|MSG_CMSG_CLOEXEC);
+        k = receive_one_fd_iov(d->storage_socket[0], &iov, 1, MSG_DONTWAIT, &lock_fd);
         if (k < 0)
-                return -errno;
-
-        cmsg = cmsg_find(&mh, SOL_SOCKET, SCM_RIGHTS, CMSG_LEN(sizeof(int)));
-        if (cmsg)
-                lock_fd = *(int*) CMSG_DATA(cmsg);
-        else
-                cmsg_close_all(&mh); /* just in case... */
+                return (int) k;
 
         *ret_uid = uid;
         *ret_lock_fd = lock_fd;
@@ -352,42 +334,11 @@ static int dynamic_user_pop(DynamicUser *d, uid_t *ret_uid, int *ret_lock_fd) {
 
 static int dynamic_user_push(DynamicUser *d, uid_t uid, int lock_fd) {
         struct iovec iov = IOVEC_INIT(&uid, sizeof(uid));
-        union {
-                struct cmsghdr cmsghdr;
-                uint8_t buf[CMSG_SPACE(sizeof(int))];
-        } control = {};
-        struct msghdr mh = {
-                .msg_control = &control,
-                .msg_controllen = sizeof(control),
-                .msg_iov = &iov,
-                .msg_iovlen = 1,
-        };
-        ssize_t k;
 
         assert(d);
 
         /* Store the UID and lock_fd in the storage socket. This should be called with the socket pair lock taken. */
-
-        if (lock_fd >= 0) {
-                struct cmsghdr *cmsg;
-
-                cmsg = CMSG_FIRSTHDR(&mh);
-                cmsg->cmsg_level = SOL_SOCKET;
-                cmsg->cmsg_type = SCM_RIGHTS;
-                cmsg->cmsg_len = CMSG_LEN(sizeof(int));
-                memcpy(CMSG_DATA(cmsg), &lock_fd, sizeof(int));
-
-                mh.msg_controllen = CMSG_SPACE(sizeof(int));
-        } else {
-                mh.msg_control = NULL;
-                mh.msg_controllen = 0;
-        }
-
-        k = sendmsg(d->storage_socket[1], &mh, MSG_DONTWAIT|MSG_NOSIGNAL);
-        if (k < 0)
-                return -errno;
-
-        return 0;
+        return send_one_fd_iov(d->storage_socket[1], lock_fd, &iov, 1, MSG_DONTWAIT);
 }
 
 static void unlink_uid_lock(int lock_fd, uid_t uid, const char *name) {