diff --git a/pkgs/ch-proxy/sendfd.c b/pkgs/ch-proxy/sendfd.c index c649316..b20e284 100644 --- a/pkgs/ch-proxy/sendfd.c +++ b/pkgs/ch-proxy/sendfd.c @@ -36,3 +36,39 @@ ssize_t send_fd(int dst_fd, int fd, const struct iovec *iov) { return (sendmsg(dst_fd, &msg, 0)); } + +int recv_fd(int sock, int flags) { + int out = -1; + + struct msghdr msg = { 0 }; + struct cmsghdr *cmsg = NULL; + struct iovec iov = { 0 }; + char dummy = 0; + + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + iov.iov_base = &dummy; + iov.iov_len = sizeof(dummy); + + union { + struct cmsghdr align; + char buf[CMSG_SPACE(sizeof(int))]; + } u; + + msg.msg_control = u.buf; + msg.msg_controllen = sizeof(u.buf); + + int bytes = 0; + if ((bytes = recvmsg(sock, &msg, flags)) < 0) { + perror("recv_fd: recvmsg"); + return -1; + } + for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; cmsg = CMSG_NXTHDR(&msg, cmsg)) { + if (cmsg->cmsg_level != SOL_SOCKET) { continue; } + if (cmsg->cmsg_type != SCM_RIGHTS) { continue; } + if (CMSG_LEN(cmsg) < sizeof(out)) { continue; } + out = *(int*)CMSG_DATA(cmsg); + } + return out; +} diff --git a/pkgs/ch-proxy/sendfd.h b/pkgs/ch-proxy/sendfd.h index 8c99389..fc1d2f8 100644 --- a/pkgs/ch-proxy/sendfd.h +++ b/pkgs/ch-proxy/sendfd.h @@ -5,7 +5,23 @@ #include /* ssize_t */ #include /* iovec */ -ssize_t send_fd(int dst_fd, int fd, const struct iovec *); + +/* send_fd(chanFd, fd, *iov) + * + * chanFd: fd to sendmsg over; + * fd: fd to send; + * iov: extra data to send or NULL; + * + * returns: result of sendmsg, + * i.e. the number of bytes sent */ +ssize_t send_fd(int chanFd, int fd, const struct iovec *); + +/* recv_fd(chanFd, flags) + * + * chanFd: fd to recvmsg from; + * flags: recvmsg flags e.g. 0, or MSG_CMSG_CLOEXEC? + * + * returns: the received fd or -1 */ +int recv_fd(int chanFd, int flags); #endif /* _CH_PROXY_SENFD */ - diff --git a/pkgs/taps/main.c b/pkgs/taps/main.c index 82a02cd..05edb64 100644 --- a/pkgs/taps/main.c +++ b/pkgs/taps/main.c @@ -10,10 +10,11 @@ #include #include #include /* open, O_NONBLOCK, &c */ +#include #include #include #include - +#include #define __UAPI_DEF_IF_IFNAMSIZ 1 #include @@ -25,13 +26,19 @@ #define SUN_PATH_SZ 108 #define N_CONNS 16 -#define IFR_FLAGS_ALLOWED (IFF_NO_PI | IFF_TAP | IFF_TUN | IFF_VNET_HDR | IFF_MULTI_QUEUE) -#define IFR_FLAGS_DEFAULT (IFF_NO_PI | IFF_TAP | IFF_VNET_HDR) +char *TEMP_PATHS[1024] = { 0 }; +int LAST_TEMP_PATH = -1; -#define DO_OR_DIE(expr) DO_OR_DIE_X((expr) == 0) -#define DO_OR_DIE_X(expr) \ +#define IFR_FLAGS_ALLOWED (IFF_NO_PI | IFF_TAP | IFF_TUN | IFF_VNET_HDR | IFF_MULTI_QUEUE | IFF_PERSIST) +#define IFR_FLAGS_DEFAULT (IFF_NO_PI | IFF_TAP | IFF_VNET_HDR | IFF_PERSIST) + +#define PTR_OR_DIE(expr) TRUE_OR_DIE((expr) != NULL) +#define DO_OR_DIE(expr) TRUE_OR_DIE((expr) != -1) +#define TRUE_OR_DIE(expr, ...) TRUE_OR_(EXIT_FAILURE, expr, __VA_ARGS__) +#define TRUE_OR_WARN(expr, ...) TRUE_OR_(0, expr, __VA_ARGS__) +#define TRUE_OR_(status, expr, ...) \ do if (!(expr)) { \ - error(EXIT_FAILURE, errno, ("Failed assertion: " #expr)); \ + error(status, errno, "Failed assertion: " #expr "." __VA_ARGS__); \ } while(false) struct allow_pattern { @@ -72,38 +79,34 @@ bool match_mask(const char *test_addr, const char *expected_addr, const char *ma * `linux/Documentation/networking/tuntap.rst`. * * ifrFlags: IFF_TUN - TUN device (no Ethernet headers) - * IFF_TAP - TAP device + * IFF_TAP - TAP device * - * IFF_NO_PI - Do not provide packet information + * IFF_NO_PI - Do not provide packet information */ int tuntap_alloc(char *dev, short openFlags, short ifrFlags, int *out_fd) { - struct ifreq ifr; - int fd, err; + struct ifreq ifr = { 0 }; + int fd = -1, err = 0; - // if ((fd = open("/dev/net/tun", O_RDWR)) < 0) { - // return tun_alloc_old(dev); - // } + DO_OR_DIE(fd = open("/dev/net/tun", openFlags)); - DO_OR_DIE_X((fd = open("/dev/net/tun", openFlags)) >= 0); - - memset(&ifr, 0, sizeof(ifr)); - - if (*dev) { + if (dev != NULL) { int devLen = strlen(dev); if (devLen >= IFNAMSIZ) { /* If client requests a name, we do want the entire name to fit */ errno = EINVAL; return EINVAL; } - strncpy(ifr.ifr_name, dev, IFNAMSIZ); + strncpy(ifr.ifr_name, dev, IFNAMSIZ - 1); } + ifr.ifr_flags = ifrFlags; - if ((err = ioctl(fd, TUNSETIFF, (void *)&ifr)) < 0) { + TRUE_OR_WARN((err = ioctl(fd, TUNSETIFF, (void *)&ifr)) == 0); + if (err != 0) { close(fd); return err; } - strcpy(dev, ifr.ifr_name); + strncpy(dev, ifr.ifr_name, IFNAMSIZ); *out_fd = fd; return 0; } @@ -111,18 +114,20 @@ int tuntap_alloc(char *dev, short openFlags, short ifrFlags, int *out_fd) { int acceptRequests(const char *requestsPath, const struct allow_patterns *patterns) { int listener; struct sockaddr_un addr; - const bool t = 1; + const int t = 1; DO_OR_DIE(listener = socket(AF_UNIX, SOCK_SEQPACKET, 0)); - DO_OR_DIE(setsockopt(listener, SOL_SOCKET, SO_PASSCRED, &t, 1) != 0); + DO_OR_DIE(setsockopt(listener, SOL_SOCKET, SO_PASSCRED, &t, sizeof(t))); addr.sun_family = AF_UNIX; - strncpy(addr.sun_path, requestsPath, SUN_PATH_SZ); - DO_OR_DIE (bind(listener, &addr, sizeof(addr)) == -1); + strncpy(addr.sun_path, requestsPath, SUN_PATH_SZ - 1); + DO_OR_DIE (bind(listener, &addr, sizeof(addr))); + PTR_OR_DIE(TEMP_PATHS[++LAST_TEMP_PATH] = strdup(requestsPath)); DO_OR_DIE(listen(listener, N_CONNS)); for (;;) { + /* Already changed my mind about looking at ucred, but keeping the code around for now */ int sock = -1; struct ucred cred = { 0 }; struct msghdr msg = { 0 }; @@ -136,13 +141,14 @@ int acceptRequests(const char *requestsPath, const struct allow_patterns *patter iov.iov_base = &req; iov.iov_len = sizeof(struct tap_request); - DO_OR_DIE(sock = accept(listener, NULL, NULL)); + DO_OR_DIE((sock = accept(listener, NULL, NULL))); - DO_OR_DIE_X(recvmsg(sock, &msg, 0) > 0); + TRUE_OR_DIE(recvmsg(sock, &msg, 0) > 0); + req.name[IFNAMSIZ] = 0; for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; cmsg = CMSG_NXTHDR(&msg, cmsg)) { - if (!(cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_CREDENTIALS)) { - continue; - } + if (cmsg->cmsg_level != SOL_SOCKET) { continue; } + if (cmsg->cmsg_type != SCM_CREDENTIALS) { continue; } + if (CMSG_LEN(cmsg) < sizeof(struct ucred)) { continue; } memcpy(&cred, CMSG_DATA(cmsg), sizeof(struct ucred)); break; } @@ -162,12 +168,13 @@ int acceptRequests(const char *requestsPath, const struct allow_patterns *patter if (!allowed) { reply.status = AUTH_ERROR; } if (allowed) { /* O_CLOEXEC? */ - int fd = 0; - DO_OR_DIE(tuntap_alloc(req.name, O_RDWR | O_NONBLOCK, req.ifrFlags, &fd)); + int fd = -1; + TRUE_OR_DIE(tuntap_alloc(req.name, O_RDWR | O_NONBLOCK, req.ifrFlags, &fd) == 0); struct iovec iov = { 0 }; iov.iov_base = &reply; iov.iov_len = sizeof(struct tap_reply); - DO_OR_DIE_X(send_fd(sock, fd, &iov) > 0); + TRUE_OR_DIE(send_fd(sock, fd, &iov) > 0); + close(fd); } close(sock); } @@ -185,7 +192,8 @@ struct allow_patterns parsePatterns(const char *raw) { if (start < i) { ++nPatterns; } } - struct allow_pattern *patterns = calloc(nPatterns, sizeof(struct allow_pattern)); + struct allow_pattern *patterns = NULL; + PTR_OR_DIE(patterns = calloc(nPatterns, sizeof(struct allow_pattern))); int iPattern = 0; for (int i = 0; i < rawLen; ++i) { @@ -194,8 +202,8 @@ struct allow_patterns parsePatterns(const char *raw) { { const int start = i; for (; i < rawLen && !isspace(raw[i]); ++i) { } - if (i < rawLen) { - patterns[iPattern].name = strndup(&raw[start], i - start); + if (start < i) { + PTR_OR_DIE(patterns[iPattern].name = strndup(&raw[start], i - start)); iPattern += 1; } } @@ -207,37 +215,109 @@ struct allow_patterns parsePatterns(const char *raw) { return out; } +int get(const char *servePath, const char *ifname, short ifrFlags) { + /* TODO: sock: move out */ + int sock; + struct sockaddr_un addr; + + DO_OR_DIE(sock = socket(AF_UNIX, SOCK_SEQPACKET, 0)); + + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, servePath, SUN_PATH_SZ - 1); + DO_OR_DIE (connect(sock, &addr, sizeof(addr))); + + struct msghdr msg = { 0 }; + struct cmsghdr *cmsg = NULL; + struct iovec iov = { 0 }; + struct tap_request req = { 0 }; + strncpy(req.name, ifname, IFNAMSIZ - 1); + req.ifrFlags = ifrFlags; + + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + iov.iov_base = &req; + iov.iov_len = sizeof(struct tap_request); + + TRUE_OR_DIE(sendmsg(sock, &msg, 0) > 0); + + int tunFd = -1; + DO_OR_DIE(tunFd = recv_fd(sock, 0)); + close(sock); + return tunFd; +} + +void cleanup(int signo, siginfo_t *info, void *_context) { + for (int i = 0; i <= LAST_TEMP_PATH; ++i) { + TRUE_OR_DIE(unlink(TEMP_PATHS[i]) != -1 || errno == ENOENT); + } + if (signo == SIGINT) { + exit(EXIT_SUCCESS); + } + errx(EXIT_FAILURE, "Exiting with signal %d", signo); +} + int main(int argc, char **argv) { + struct sigaction act = { 0 }; + act.sa_flags = SA_SIGINFO; + act.sa_sigaction = cleanup; + DO_OR_DIE(sigaction(SIGINT, &act, NULL)); + DO_OR_DIE(sigaction(SIGSEGV, &act, NULL)); + bool cmdServe = false; - bool cmdGet = false; + bool cmdPass = false; + char *ifname = "vt%d"; char **rest = argv + 1; + char **end = argv + argc; - DO_OR_DIE_X(argc > 1); + TRUE_OR_DIE(argc > 1); if (strcmp(rest[0], "serve") == 0) { cmdServe = true; ++rest; - } else if (strcmp(rest[0], "get") == 0) { - cmdGet = true; + } else if (strcmp(rest[0], "pass") == 0) { + cmdPass = true; ++rest; + for (; rest != end && rest[0][0] == '-'; ++rest) { + if (strcmp(rest[0], "--")) { break; } + else if (strncmp(rest[0], "--ifname=", sizeof("--ifname="))) { + ifname = rest[0] + sizeof("--ifname="); + } + } } else { error(EINVAL, EINVAL, "no subcommand \"%s\"", rest[0]); } + int nextArgc = argc - (rest - argv); + char * const* nextArgv = rest; + const char *patternsRaw = secure_getenv("TAPS_ALLOW"); + if (patternsRaw == NULL) { + patternsRaw = "*"; + } + struct allow_patterns patterns = { 0 }; - if (cmdServe && patternsRaw != NULL) { - patterns = parsePatterns(patternsRaw); - DO_OR_DIE_X(patterns.patterns != NULL); + if (cmdServe) { + PTR_OR_DIE((patterns = parsePatterns(patternsRaw)).patterns); } const char *servePath = secure_getenv("TAPS_SOCK"); + if (servePath == NULL) { + servePath = "taps.sock"; + } if (cmdServe) { - error(patterns.patterns == NULL, EINVAL, "TAPS_ALLOW"); acceptRequests(servePath, &patterns); - } else if (cmdGet) { - error(ENOSYS, ENOSYS, "get"); + } else if (cmdPass) { + TRUE_OR_DIE(nextArgc > 0); + int fd = -1; + DO_OR_DIE(fd = get(servePath, ifname, 0)); + if (fd != 3) { + DO_OR_DIE(dup2(fd, 3)); + close(fd); + fd = 3; + } + DO_OR_DIE(execvp(nextArgv[0], nextArgv)); } else { error(EINVAL, EINVAL, "subcommand args"); }