changeset 5386:b28b0962bfe2 draft

Linux: keep the generic netlink socket around to get ssid with privsep While here, improve our reading of netlink(7) and terminate on either ERROR or DONE. If neither are in the message, read again unless it's the link receiving socket. Also, only callback if this is the sequence number expected.
author Roy Marples <roy@marples.name>
date Mon, 22 Jun 2020 21:56:16 +0100
parents 0397759ce185
children 08cbef5c8e9e
files src/if-linux.c
diffstat 1 files changed, 48 insertions(+), 33 deletions(-) [+]
line wrap: on
line diff
--- a/src/if-linux.c	Mon Jun 22 17:31:58 2020 +0100
+++ b/src/if-linux.c	Mon Jun 22 21:56:16 2020 +0100
@@ -130,6 +130,7 @@
 
 struct priv {
 	int route_fd;
+	int generic_fd;
 	uint32_t route_pid;
 };
 
@@ -414,6 +415,12 @@
 	if (getsockname(priv->route_fd, (struct sockaddr *)&snl, &len) == -1)
 		return -1;
 	priv->route_pid = snl.nl_pid;
+
+	memset(&snl, 0, sizeof(snl));
+	priv->generic_fd = if_linksocket(&snl, NETLINK_GENERIC, 0);
+	if (priv->generic_fd == -1)
+		return -1;
+
 	return 0;
 }
 
@@ -425,6 +432,7 @@
 	if (ctx->priv != NULL) {
 		priv = (struct priv *)ctx->priv;
 		close(priv->route_fd);
+		close(priv->generic_fd);
 	}
 }
 
@@ -465,26 +473,27 @@
 	};
 	ssize_t len;
 	struct nlmsghdr *nlm;
-	int r;
+	int r = 0;
 	unsigned int again;
+	bool terminated;
 
 recv_again:
-	if ((len = recvmsg(fd, &msg, flags)) == -1)
-		return -1;
-	if (len == 0)
-		return 0;
+	len = recvmsg(fd, &msg, flags);
+	if (len == -1 || len == 0)
+		return (int)len;
 
 	/* Check sender */
 	if (msg.msg_namelen != sizeof(nladdr)) {
 		errno = EINVAL;
 		return -1;
 	}
+
 	/* Ignore message if it is not from kernel */
 	if (nladdr.nl_pid != 0)
 		return 0;
 
-	r = 0;
 	again = 0;
+	terminated = false;
 	for (nlm = iov->iov_base;
 	     nlm && NLMSG_OK(nlm, (size_t)len);
 	     nlm = NLMSG_NEXT(nlm, len))
@@ -492,6 +501,7 @@
 		again = (nlm->nlmsg_flags & NLM_F_MULTI);
 		if (nlm->nlmsg_type == NLMSG_NOOP)
 			continue;
+
 		if (nlm->nlmsg_type == NLMSG_ERROR) {
 			struct nlmsgerr *err;
 
@@ -504,17 +514,21 @@
 				errno = -err->error;
 				return -1;
 			}
+			again = 0;
+			terminated = true;
 			break;
 		}
 		if (nlm->nlmsg_type == NLMSG_DONE) {
 			again = 0;
+			terminated = true;
 			break;
 		}
-		if (cb != NULL && (r = cb(ctx, cbarg, nlm)) != 0)
-			break;
+		if (cb != NULL &&
+		   (nlm->nlmsg_seq == (uint32_t)ctx->seq || fd == ctx->link_fd))
+			r = cb(ctx, cbarg, nlm);
 	}
 
-	if (r == 0 && again)
+	if ((again || !terminated) && (ctx != NULL && ctx->link_fd != fd))
 		goto recv_again;
 
 	return r;
@@ -982,16 +996,19 @@
 if_sendnetlink(struct dhcpcd_ctx *ctx, int protocol, struct nlmsghdr *hdr,
     int (*cb)(struct dhcpcd_ctx *, void *, struct nlmsghdr *), void *cbarg)
 {
-	int s, r;
+	int s;
 	struct sockaddr_nl snl = { .nl_family = AF_NETLINK };
 	struct iovec iov = { .iov_base = hdr, .iov_len = hdr->nlmsg_len };
 	struct msghdr msg = {
 	    .msg_name = &snl, .msg_namelen = sizeof(snl),
 	    .msg_iov = &iov, .msg_iovlen = 1
 	};
-	bool use_rfd;
-
-	use_rfd = (protocol == NETLINK_ROUTE && hdr->nlmsg_type != RTM_GETADDR);
+	struct priv *priv = (struct priv *)ctx->priv;
+	unsigned char buf[16 * 1024];
+	struct iovec riov = {
+		.iov_base = buf,
+		.iov_len = sizeof(buf),
+	};
 
 	/* Request a reply */
 	hdr->nlmsg_flags |= NLM_F_ACK;
@@ -1002,13 +1019,16 @@
 		return (int)ps_root_sendnetlink(ctx, protocol, &msg);
 #endif
 
-	if (use_rfd) {
-		struct priv *priv = (struct priv *)ctx->priv;
-
-		s = priv->route_fd;
-	} else {
-		if ((s = if_linksocket(&snl, protocol, 0)) == -1)
-			return -1;
+	switch (protocol) {
+	case NETLINK_ROUTE:
+		if (hdr->nlmsg_type != RTM_GETADDR) {
+			s = priv->route_fd;
+			break;
+		}
+		/* FALLTHROUGH */
+	case NETLINK_GENERIC:
+		s = priv->generic_fd;
+#if 0
 #ifdef NETLINK_GET_STRICT_CHK
 		if (hdr->nlmsg_type == RTM_GETADDR) {
 			int on = 1;
@@ -1018,22 +1038,17 @@
 				logerr("%s: NETLINK_GET_STRICT_CHK", __func__);
 		}
 #endif
+#endif
+		break;
+	default:
+		errno = EINVAL;
+		return -1;
 	}
 
-	if (sendmsg(s, &msg, 0) != -1) {
-		unsigned char buf[16 * 1024];
-		struct iovec riov = {
-			.iov_base = buf,
-			.iov_len = sizeof(buf),
-		};
+	if (sendmsg(s, &msg, 0) == -1)
+		return -1;
 
-		r = if_getnetlink(ctx, &riov, s, 0, cb, cbarg);
-	} else
-		r = -1;
-
-	if (!use_rfd)
-		close(s);
-	return r;
+	return if_getnetlink(ctx, &riov, s, 0, cb, cbarg);
 }
 
 #define NLMSG_TAIL(nmsg)						\