changeset 4608:3aef2b50ef03 draft

DHCP: Rework checksuming so that the packet isn't touched. Other than setting udp->uh_sum to zero which we need to do to calculate the checksum. Also, the UDP checksum needs to include a pseudo IP header without options and mostly blank. Instead of changing the packet, just checksum a blank object we've filled in with the needed data from the given IP object and use this to start the UDP checksum calculation with. While here, improve the checksum function so it more matches the in_cksum function as noted in RFC 1071 4.1 using 16 byte words.
author Roy Marples <roy@marples.name>
date Wed, 31 Jul 2019 09:39:58 +0100
parents 8e54887526a6
children 636120b95e27 a19f5fd758f6
files src/dhcp.c
diffstat 1 files changed, 28 insertions(+), 30 deletions(-) [+]
line wrap: on
line diff
--- a/src/dhcp.c	Tue Jul 30 11:04:24 2019 +0100
+++ b/src/dhcp.c	Wed Jul 31 09:39:58 2019 +0100
@@ -1584,24 +1584,24 @@
 }
 
 static uint16_t
-checksum(const void *data, size_t len)
+in_cksum(void *data, size_t len, uint32_t *isum)
 {
-	const uint8_t *addr = data;
-	uint32_t sum = 0;
-
-	while (len > 1) {
-		sum += (uint32_t)(addr[0] * 256 + addr[1]);
-		addr += 2;
-		len -= 2;
-	}
+	const uint16_t *word = data;
+	uint32_t sum = isum != NULL ? *isum : 0;
+
+	for (; len > 1; len -= sizeof(*word))
+		sum += *word++;
 
 	if (len == 1)
-		sum += (uint32_t)(*addr * 256);
+		sum += *(const uint8_t *)word;
+
+	if (isum != NULL)
+		*isum = sum;
 
 	sum = (sum >> 16) + (sum & 0xffff);
 	sum += (sum >> 16);
 
-	return (uint16_t)~htons((uint16_t)sum);
+	return (uint16_t)~sum;
 }
 
 static struct bootp_pkt *
@@ -1639,14 +1639,14 @@
 	udp->uh_dport = htons(BOOTPS);
 	udp->uh_ulen = htons((uint16_t)(sizeof(*udp) + length));
 	ip->ip_len = udp->uh_ulen;
-	udp->uh_sum = checksum(udpp, sizeof(*ip) +  sizeof(*udp) + length);
+	udp->uh_sum = in_cksum(udpp, sizeof(*ip) + sizeof(*udp) + length, NULL);
 
 	ip->ip_v = IPVERSION;
 	ip->ip_hl = sizeof(*ip) >> 2;
 	ip->ip_id = (uint16_t)arc4random_uniform(UINT16_MAX);
 	ip->ip_ttl = IPDEFTTL;
 	ip->ip_len = htons((uint16_t)(sizeof(*ip) + sizeof(*udp) + length));
-	ip->ip_sum = checksum(ip, sizeof(*ip));
+	ip->ip_sum = in_cksum(ip, sizeof(*ip), NULL);
 
 	*sz = sizeof(*ip) + sizeof(*udp) + length;
 	return udpp;
@@ -3236,10 +3236,15 @@
 	unsigned int flags)
 {
 	struct ip *ip = packet;
-	char ip_hlv = *(char *)ip;
+	struct ip pseudo_ip = {
+		.ip_p = IPPROTO_UDP,
+		.ip_src = ip->ip_src,
+		.ip_dst = ip->ip_dst
+	};
 	size_t ip_hlen;
 	uint16_t ip_len, uh_sum;
 	struct udphdr *udp;
+	uint32_t csum;
 
 	if (plen < sizeof(*ip)) {
 		if (from != NULL)
@@ -3252,13 +3257,13 @@
 		from->s_addr = ip->ip_src.s_addr;
 
 	ip_hlen = (size_t)ip->ip_hl * 4;
-	if (checksum(ip, ip_hlen) != 0) {
+	if (in_cksum(ip, ip_hlen, NULL) != 0) {
 		errno = EINVAL;
 		return -1;
 	}
 
+	/* Check we have a payload */
 	ip_len = ntohs(ip->ip_len);
-	/* Check we have a payload */
 	if (ip_len <= ip_hlen + sizeof(*udp)) {
 		errno = ERANGE;
 		return -1;
@@ -3272,28 +3277,21 @@
 	if (flags & BPF_PARTIALCSUM)
 		return 0;
 
+	/* UDP checksum is based on a pseudo IP header alongside
+	 * the UDP header and payload. */
 	udp = (struct udphdr *)((char *)ip + ip_hlen);
 	if (udp->uh_sum == 0)
 		return 0;
+
 	uh_sum = udp->uh_sum;
-
-	/* This does scribble on the packet, but at this point
-	 * we don't care to keep it. */
 	udp->uh_sum = 0;
-	ip->ip_hl = 0;
-	ip->ip_v = 0;
-	ip->ip_tos = 0;
-	ip->ip_len = udp->uh_ulen;
-	ip->ip_id = 0;
-	ip->ip_off = 0;
-	ip->ip_ttl = 0;
-	ip->ip_sum = 0;
-	if (checksum(packet, ip_len) != uh_sum) {
+	pseudo_ip.ip_len = udp->uh_ulen;
+	csum = 0;
+	in_cksum(&pseudo_ip, sizeof(pseudo_ip), &csum);
+	if (in_cksum(udp, ntohs(udp->uh_ulen), &csum) != uh_sum) {
 		errno = EINVAL;
 		return -1;
 	}
-	*(char *)ip = ip_hlv;
-	ip->ip_len = htons(ip_len);
 
 	return 0;
 }