commit 98c6998116e33f9f34b798682e0695f4166bd86d
Author: Simon Kelley <simon@thekelleys.org.uk>
Date: Mon Mar 2 17:10:25 2020 +0000
Optimise closing file descriptors.
Dnsmasq needs to close all the file descriptors it inherits, for security
reasons. This is traditionally done by calling close() on every possible
file descriptor (most of which won't be open.) On big servers where
"every possible file descriptor" is a rather large set, this gets
rather slow, so we use the /proc/<pid>/fd directory to get a list
of the fds which are acually open.
This only works on Linux. On other platforms, and on Linux systems
without a /proc filesystem, we fall back to the old way.
diff --git a/src/dnsmasq.c b/src/dnsmasq.c
index 573aac0..10f19ea 100644
--- a/src/dnsmasq.c
+++ b/src/dnsmasq.c
@@ -138,20 +138,18 @@ int main (int argc, char **argv)
}
#endif
- /* Close any file descriptors we inherited apart from std{in|out|err}
-
- Ensure that at least stdin, stdout and stderr (fd 0, 1, 2) exist,
+ /* Ensure that at least stdin, stdout and stderr (fd 0, 1, 2) exist,
otherwise file descriptors we create can end up being 0, 1, or 2
and then get accidentally closed later when we make 0, 1, and 2
open to /dev/null. Normally we'll be started with 0, 1 and 2 open,
but it's not guaranteed. By opening /dev/null three times, we
ensure that we're not using those fds for real stuff. */
- for (i = 0; i < max_fd; i++)
- if (i != STDOUT_FILENO && i != STDERR_FILENO && i != STDIN_FILENO)
- close(i);
- else
- open("/dev/null", O_RDWR);
-
+ for (i = 0; i < 3; i++)
+ open("/dev/null", O_RDWR);
+
+ /* Close any file descriptors we inherited apart from std{in|out|err} */
+ close_fds(max_fd, -1, -1, -1);
+
#ifndef HAVE_LINUX_NETWORK
# if !(defined(IP_RECVDSTADDR) && defined(IP_RECVIF) && defined(IP_SENDSRCADDR))
if (!option_bool(OPT_NOWILD))
diff --git a/src/dnsmasq.h b/src/dnsmasq.h
index 6103eb5..c46bfeb 100644
--- a/src/dnsmasq.h
+++ b/src/dnsmasq.h
@@ -1283,7 +1283,7 @@ int memcmp_masked(unsigned char *a, unsigned char *b, int len,
int expand_buf(struct iovec *iov, size_t size);
char *print_mac(char *buff, unsigned char *mac, int len);
int read_write(int fd, unsigned char *packet, int size, int rw);
-
+void close_fds(long max_fd, int spare1, int spare2, int spare3);
int wildcard_match(const char* wildcard, const char* match);
int wildcard_matchn(const char* wildcard, const char* match, int num);
diff --git a/src/helper.c b/src/helper.c
index 1b260a1..7072cf4 100644
--- a/src/helper.c
+++ b/src/helper.c
@@ -131,12 +131,8 @@ int create_helper(int event_fd, int err_fd, uid_t uid, gid_t gid, long max_fd)
Don't close err_fd, in case the lua-init fails.
Note that we have to do this before lua init
so we don't close any lua fds. */
- for (max_fd--; max_fd >= 0; max_fd--)
- if (max_fd != STDOUT_FILENO && max_fd != STDERR_FILENO &&
- max_fd != STDIN_FILENO && max_fd != pipefd[0] &&
- max_fd != event_fd && max_fd != err_fd)
- close(max_fd);
-
+ close_fds(max_fd, pipefd[0], event_fd, err_fd);
+
#ifdef HAVE_LUASCRIPT
if (daemon->luascript)
{
diff --git a/src/util.c b/src/util.c
index 73bf62a..f058c92 100644
--- a/src/util.c
+++ b/src/util.c
@@ -705,6 +705,47 @@ int read_write(int fd, unsigned char *packet, int size, int rw)
return 1;
}
+/* close all fds except STDIN, STDOUT and STDERR, spare1, spare2 and spare3 */
+void close_fds(long max_fd, int spare1, int spare2, int spare3)
+{
+ /* On Linux, use the /proc/ filesystem to find which files
+ are actually open, rather than iterate over the whole space,
+ for efficiency reasons. If this fails we drop back to the dumb code. */
+#ifdef HAVE_LINUX_NETWORK
+ DIR *d;
+
+ if ((d = opendir("/proc/self/fd")))
+ {
+ struct dirent *de;
+
+ while ((de = readdir(d)))
+ {
+ long fd;
+ char *e = NULL;
+
+ errno = 0;
+ fd = strtol(de->d_name, &e, 10);
+
+ if (errno != 0 || !e || *e || fd == dirfd(d) ||
+ fd == STDOUT_FILENO || fd == STDERR_FILENO || fd == STDIN_FILENO ||
+ fd == spare1 || fd == spare2 || fd == spare3)
+ continue;
+
+ close(fd);
+ }
+
+ closedir(d);
+ return;
+ }
+#endif
+
+ /* fallback, dumb code. */
+ for (max_fd--; max_fd >= 0; max_fd--)
+ if (max_fd != STDOUT_FILENO && max_fd != STDERR_FILENO && max_fd != STDIN_FILENO &&
+ max_fd != spare1 && max_fd != spare2 && max_fd != spare3)
+ close(max_fd);
+}
+
/* Basically match a string value against a wildcard pattern. */
int wildcard_match(const char* wildcard, const char* match)
{