Print this page
11848 Remove STRUIO_IP support from ip_cksum.c

@@ -21,16 +21,20 @@
 /*
  * Copyright 2009 Sun Microsystems, Inc.  All rights reserved.
  * Use is subject to license terms.
  */
 /* Copyright (c) 1990 Mentat Inc. */
+/*
+ * Copyright 2019 Joyent, Inc.
+ */
 
 #include <sys/types.h>
 #include <sys/inttypes.h>
 #include <sys/systm.h>
 #include <sys/stream.h>
 #include <sys/strsun.h>
+#include <sys/strsubr.h>
 #include <sys/debug.h>
 #include <sys/ddi.h>
 #include <sys/vtrace.h>
 #include <inet/sctp_crc32.h>
 #include <inet/ip.h>

@@ -45,342 +49,190 @@
  * Checksum routine for Internet Protocol family headers.
  * This routine is very heavily used in the network
  * code and should be modified for each CPU to be as fast as possible.
  */
 
-#define mp_len(mp) ((mp)->b_wptr - (mp)->b_rptr)
-
 /*
  * Even/Odd checks. Usually it is performed on pointers but may be
  * used on integers as well. uintptr_t is long enough to hold both
  * integer and pointer.
  */
-#define is_odd(p) (((uintptr_t)(p) & 0x1) != 0)
-#define is_even(p) (!is_odd(p))
+#define IS_ODD(p)       (((uintptr_t)(p) & 0x1) != 0)
+#define IS_EVEN(p)      (((uintptr_t)(p) & 0x1) == 0)
 
-
-#ifdef ZC_TEST
 /*
- * Disable the TCP s/w cksum.
- * XXX - This is just a hack for testing purpose. Don't use it for
- * anything else!
  */
-int noswcksum = 0;
+#define HAS_UIOSUM(mp) ((mp)->b_datap->db_struioflag & STRUIO_IP)
+
+#ifdef _LITTLE_ENDIAN
+#define FRAG(ptr) (*(ptr))
+#else
+#define FRAG(ptr) (*(ptr) << 8)
 #endif
+
 /*
+ * Give the compiler a hint to help optimize the code layout
+ */
+#define UNLIKELY(exp) __builtin_expect((exp), 0)
+
+#define FOLD(val) (((val) & 0xFFFF) + ((val) >> 16))
+
+/*
  * Note: this does not ones-complement the result since it is used
- * when computing partial checksums.
- * For nonSTRUIO_IP mblks, assumes mp->b_rptr+offset is 16 bit aligned.
- * For STRUIO_IP mblks, assumes mp->b_datap->db_struiobase is 16 bit aligned.
- *
- * Note: for STRUIO_IP special mblks some data may have been previously
- *       checksumed, this routine will handle additional data prefixed within
- *       an mblk or b_cont (chained) mblk(s). This routine will also handle
- *       suffixed b_cont mblk(s) and data suffixed within an mblk.
+ * when computing partial checksums.  It assumes mp->b_rptr + offset is
+ * 16 bit aligned and a valid offset in mp.
  */
 unsigned int
-ip_cksum(mblk_t *mp, int offset, uint_t sum)
+ip_cksum(mblk_t *mp, int offset, uint_t initial_sum)
 {
-        ushort_t *w;
-        ssize_t mlen;
-        int pmlen;
-        mblk_t *pmp;
-        dblk_t *dp = mp->b_datap;
-        ushort_t psum = 0;
+        const uint_t sum_mask[2] = { 0, UINT_MAX };
+        uint64_t sum = initial_sum;
+        uint64_t total_len = 0;
+        uchar_t *w;
+        size_t mlen = MBLKL(mp);
+        uint_t msum, mask;
 
-#ifdef ZC_TEST
-        if (noswcksum)
-                return (0xffff);
-#endif
-        ASSERT(dp);
+        ASSERT3S(offset, >=, 0);
 
-        if (mp->b_cont == NULL) {
-                /*
-                 * May be fast-path, only one mblk.
-                 */
-                w = (ushort_t *)(mp->b_rptr + offset);
-                if (dp->db_struioflag & STRUIO_IP) {
-                        /*
-                         * Checksum any data not already done by
-                         * the caller and add in any partial checksum.
-                         */
-                        if ((offset > dp->db_cksumstart) ||
-                            mp->b_wptr != (uchar_t *)(mp->b_rptr +
-                            dp->db_cksumend)) {
-                                /*
-                                 * Mblk data pointers aren't inclusive
-                                 * of uio data, so disregard checksum.
-                                 *
-                                 * not using all of data in dblk make sure
-                                 * not use to use the precalculated checksum
-                                 * in this case.
-                                 */
-                                dp->db_struioflag &= ~STRUIO_IP;
-                                goto norm;
+        VERIFY(!HAS_UIOSUM(mp));
+        while (UNLIKELY(offset > mlen)) {
+                ASSERT3P(mp->b_cont, !=, NULL);
+                mp = mp->b_cont;
+                VERIFY(!HAS_UIOSUM(mp));
+                offset -= mlen;
+                mlen = MBLKL(mp);
                         }
-                        ASSERT(mp->b_wptr == (mp->b_rptr + dp->db_cksumend));
-                        psum = *(ushort_t *)dp->db_struioun.data;
-                        if ((mlen = dp->db_cksumstart - offset) < 0)
-                                mlen = 0;
-                        if (is_odd(mlen))
-                                goto slow;
-                        if (mlen && dp->db_cksumstart != dp->db_cksumstuff &&
-                            dp->db_cksumend != dp->db_cksumstuff) {
+
                                 /*
-                                 * There is prefix data to do and some uio
-                                 * data has already been checksumed and there
-                                 * is more uio data to do, so do the prefix
-                                 * data first, then do the remainder of the
-                                 * uio data.
+         * Make sure we start with a folded sum.  Since the initial sum
+         * is 32 bits, folding twice will always produce a sum <= 0xFFFF
                                  */
-                                sum = ip_ocsum(w, mlen >> 1, sum);
-                                w = (ushort_t *)(mp->b_rptr +
-                                    dp->db_cksumstuff);
-                                if (is_odd(w)) {
-                                        pmp = mp;
-                                        goto slow1;
+        sum = FOLD(sum);
+        sum = FOLD(sum);
+        ASSERT3U(sum, <=, 0xFFFF);
+
+        while (mp != NULL) {
+                w = mp->b_rptr + offset;
+                mlen = mp->b_wptr - w;
+                offset = 0;
+
+                ASSERT3P(w, <=, mp->b_wptr);
+                VERIFY(!HAS_UIOSUM(mp));
+
+                if (UNLIKELY(mlen == 0)) {
+                        mp = mp->b_cont;
+                        continue;
                                 }
-                                mlen = dp->db_cksumend - dp->db_cksumstuff;
-                        } else if (dp->db_cksumend != dp->db_cksumstuff) {
-                                /*
-                                 * There may be uio data to do, if there is
-                                 * prefix data to do then add in all of the
-                                 * uio data (if any) to do, else just do any
-                                 * uio data.
-                                 */
-                                if (mlen)
-                                        mlen += dp->db_cksumend
-                                            - dp->db_cksumstuff;
-                                else {
-                                        w = (ushort_t *)(mp->b_rptr +
-                                            dp->db_cksumstuff);
-                                        if (is_odd(w))
-                                                goto slow;
-                                        mlen = dp->db_cksumend
-                                            - dp->db_cksumstuff;
-                                }
-                        } else if (mlen == 0)
-                                return (psum);
 
-                        if (is_odd(mlen))
-                                goto slow;
-                        sum += psum;
-                } else {
                         /*
-                         * Checksum all data not already done by the caller.
+                 * ip_ocsum() currently requires a 16-bit aligned address.
+                 * For unaligned buffers, we first sum the odd byte (and
+                 * fold if necessary) before calling ip_ocsum(). ip_ocsum()
+                 * also takes its length in units of 16-bit words.  If
+                 * we have an odd length, we must also manually add it after
+                 * computing the main sum (and again fold if necessary).
+                 *
+                 * Since ip_ocsum() _should_ be a private per-platform
+                 * optimized ip cksum implementation (with ip_cksum() being
+                 * the less-private wrapper around it), a nice future
+                 * optimization could be to modify ip_ocsum() for each
+                 * platform to take a 64-bit sum.  This would allow us to
+                 * only have to fold exactly once before we return --
+                 * the amount of data we'd need to checksum to overflow 64
+                 * bits far exceeds the possible size of any mblk chain we
+                 * could ever have.
                          */
-                norm:
-                        mlen = mp->b_wptr - (uchar_t *)w;
-                        if (is_odd(mlen))
-                                goto slow;
+                if (UNLIKELY(IS_ODD(w))) {
+                        sum += FRAG(w);
+                        w++;
+
+                        --mlen;
+                        total_len++;
+
+                        if (UNLIKELY(mlen == 0)) {
+                                mp = mp->b_cont;
+                                continue;
                 }
-                ASSERT(is_even(w));
-                ASSERT(is_even(mlen));
-                return (ip_ocsum(w, mlen >> 1, sum));
         }
-        if (dp->db_struioflag & STRUIO_IP)
-                psum = *(ushort_t *)dp->db_struioun.data;
-slow:
-        pmp = 0;
-slow1:
-        mlen = 0;
-        pmlen = 0;
-        for (; ; ) {
+
                 /*
-                 * Each trip around loop adds in word(s) from one mbuf segment
-                 * (except for when pmp == mp, then its two partial trips).
+                 * ip_ocsum() takes the length as the number of half words
+                 * (i.e. uint16_ts). It returns a result that is already
+                 * folded (<= 0xFFFF).
                  */
-                w = (ushort_t *)(mp->b_rptr + offset);
-                if (pmp) {
+                msum = ip_ocsum((ushort_t *)w, mlen / 2, 0);
+                ASSERT3U(msum, <=, 0xFFFF);
+
                         /*
-                         * This is the second trip around for this mblk.
+                 * We mask the last byte based on the length of data.
+                 * If the length is odd, we AND with UINT_MAX; otherwise
+                 * we AND with 0 (resulting in 0), and add the result to
+                 * the mblk_t sum. This effectively gives us:
+                 *
+                 * if (IS_ODD(mlen))
+                 *      msum += FRAG(w + mlen - 1);
+                 * else
+                 *      msum += 0;
+                 *
+                 * Without incurring a branch.
                          */
-                        pmp = 0;
-                        mlen = 0;
-                        goto douio;
-                } else if (dp->db_struioflag & STRUIO_IP) {
+                mask = sum_mask[IS_ODD(mlen)];
+                msum += FRAG(w + mlen - 1) & mask;
+
                         /*
-                         * Checksum any data not already done by the
-                         * caller and add in any partial checksum.
-                         */
-                        if ((offset > dp->db_cksumstart) ||
-                            mp->b_wptr != (uchar_t *)(mp->b_rptr +
-                            dp->db_cksumend)) {
-                                /*
-                                 * Mblk data pointers aren't inclusive
-                                 * of uio data, so disregard checksum.
+                 * If the data we are checksumming has been split
+                 * between two mblk_ts along a non-16-bit boundary, that is,
+                 * we have something like:
+                 *      mblk_t 1: aa bb cc
+                 *      mblk_t 2: dd ee ff
+                 * the result must be the same as if we checksummed a
+                 * single mblk_t with 'aa bb cc dd ee ff'. As can be seen
+                 * from the example, this situation causes the grouping of
+                 * the data in the second mblk_t to be offset by a byte.
+                 * The fix is to byteswap the mblk_t sum before adding it
+                 * to the final sum. Again, we AND the mblk_t sum with a mask
+                 * so that either the non-swapped or byteswapped sum is zeroed
+                 * out and the other one is preserved (depending on the
+                 * total bytes checksummed so far) and added to the sum.
                                  *
-                                 * not using all of data in dblk make sure
-                                 * not use to use the precalculated checksum
-                                 * in this case.
+                 * Effectively,
+                 *
+                 * if (IS_ODD(total_len))
+                 *      sum += BSWAP_32(msum);
+                 * else
+                 *      sum += msum;
                                  */
-                                dp->db_struioflag &= ~STRUIO_IP;
-                                goto snorm;
-                        }
-                        ASSERT(mp->b_wptr == (mp->b_rptr + dp->db_cksumend));
-                        if ((mlen = dp->db_cksumstart - offset) < 0)
-                                mlen = 0;
-                        if (mlen && dp->db_cksumstart != dp->db_cksumstuff) {
-                                /*
-                                 * There is prefix data too do and some
-                                 * uio data has already been checksumed,
-                                 * so do the prefix data only this trip.
-                                 */
-                                pmp = mp;
-                        } else {
-                                /*
-                                 * Add in any partial cksum (if any) and
-                                 * do the remainder of the uio data.
-                                 */
-                                int odd;
-                        douio:
-                                odd = is_odd(dp->db_cksumstuff -
-                                    dp->db_cksumstart);
-                                if (pmlen == -1) {
-                                        /*
-                                         * Previous mlen was odd, so swap
-                                         * the partial checksum bytes.
-                                         */
-                                        sum += ((psum << 8) & 0xffff)
-                                            | (psum >> 8);
-                                        if (odd)
-                                                pmlen = 0;
-                                } else {
-                                        sum += psum;
-                                        if (odd)
-                                                pmlen = -1;
-                                }
-                                if (dp->db_cksumend != dp->db_cksumstuff) {
-                                        /*
-                                         * If prefix data to do and then all
-                                         * the uio data nees to be checksumed,
-                                         * else just do any uio data.
-                                         */
-                                        if (mlen)
-                                                mlen += dp->db_cksumend
-                                                    - dp->db_cksumstuff;
-                                        else {
-                                                w = (ushort_t *)(mp->b_rptr +
-                                                    dp->db_cksumstuff);
-                                                mlen = dp->db_cksumend -
-                                                    dp->db_cksumstuff;
-                                        }
-                                }
-                        }
-                } else {
-                        /*
-                         * Checksum all of the mblk data.
-                         */
-                snorm:
-                        mlen = mp->b_wptr - (uchar_t *)w;
-                }
+                mask = sum_mask[IS_ODD(total_len)];
+                sum += BSWAP_32(msum) & mask;
+                sum += msum & ~mask;
 
+                total_len += mlen;
                 mp = mp->b_cont;
-                if (mlen > 0 && pmlen == -1) {
-                        /*
-                         * There is a byte left from the last
-                         * segment; add it into the checksum.
-                         * Don't have to worry about a carry-
-                         * out here because we make sure that
-                         * high part of (32 bit) sum is small
-                         * below.
-                         */
-#ifdef _LITTLE_ENDIAN
-                        sum += *(uchar_t *)w << 8;
-#else
-                        sum += *(uchar_t *)w;
-#endif
-                        w = (ushort_t *)((char *)w + 1);
-                        mlen--;
-                        pmlen = 0;
                 }
-                if (mlen > 0) {
-                        if (is_even(w)) {
-                                sum = ip_ocsum(w, mlen>>1, sum);
-                                w += mlen>>1;
+
                                 /*
-                                 * If we had an odd number of bytes,
-                                 * then the last byte goes in the high
-                                 * part of the sum, and we take the
-                                 * first byte to the low part of the sum
-                                 * the next time around the loop.
+         * To avoid unnecessary folding, we store the cumulative sum in
+         * a uint64_t. This means we can always checksum up to 2^56 bytes
+         * (2^(64-8)) without danger of overflowing.  Since 2^56 is well past
+         * the petabyte range, and is far beyond the amount of data that
+         * could ever be stored in a single mblk_t chain (for the foreseeable
+         * future), this serves more as a sanity check than anything else.
                                  */
-                                if (is_odd(mlen)) {
-#ifdef _LITTLE_ENDIAN
-                                        sum += *(uchar_t *)w;
-#else
-                                        sum += *(uchar_t *)w << 8;
-#endif
-                                        pmlen = -1;
-                                }
-                        } else {
-                                ushort_t swsum;
-#ifdef _LITTLE_ENDIAN
-                                sum += *(uchar_t *)w;
-#else
-                                sum += *(uchar_t *)w << 8;
-#endif
-                                mlen--;
-                                w = (ushort_t *)(1 + (uintptr_t)w);
+        VERIFY3U(total_len, <=, (uint64_t)1 << 56);
 
-                                /* Do a separate checksum and copy operation */
-                                swsum = ip_ocsum(w, mlen>>1, 0);
-                                sum += ((swsum << 8) & 0xffff) | (swsum >> 8);
-                                w += mlen>>1;
                                 /*
-                                 * If we had an even number of bytes,
-                                 * then the last byte goes in the low
-                                 * part of the sum.  Otherwise we had an
-                                 * odd number of bytes and we take the first
-                                 * byte to the low part of the sum the
-                                 * next time around the loop.
+         * For a 64-bit sum, we have to fold at most 4 times to
+         * produce a sum <= 0xFFFF.
                                  */
-                                if (is_odd(mlen)) {
-#ifdef _LITTLE_ENDIAN
-                                        sum += *(uchar_t *)w << 8;
-#else
-                                        sum += *(uchar_t *)w;
-#endif
-                                }
-                                else
-                                        pmlen = -1;
-                        }
-                }
-                /*
-                 * Locate the next block with some data.
-                 * If there is a word split across a boundary we
-                 * will wrap to the top with mlen == -1 and
-                 * then add it in shifted appropriately.
-                 */
-                offset = 0;
-                if (! pmp) {
-                        for (; ; ) {
-                                if (mp == 0) {
-                                        goto done;
-                                }
-                                if (mp_len(mp))
-                                        break;
-                                mp = mp->b_cont;
-                        }
-                        dp = mp->b_datap;
-                        if (dp->db_struioflag & STRUIO_IP)
-                                psum = *(ushort_t *)dp->db_struioun.data;
-                } else
-                        mp = pmp;
-        }
-done:
-        /*
-         * Add together high and low parts of sum
-         * and carry to get cksum.
-         * Have to be careful to not drop the last
-         * carry here.
-         */
-        sum = (sum & 0xFFFF) + (sum >> 16);
-        sum = (sum & 0xFFFF) + (sum >> 16);
+        sum = FOLD(sum);
+        sum = FOLD(sum);
+        sum = FOLD(sum);
+        sum = FOLD(sum);
+
         TRACE_3(TR_FAC_IP, TR_IP_CKSUM_END,
             "ip_cksum_end:(%S) type %d (%X)", "ip_cksum", 1, sum);
-        return (sum);
+        return ((unsigned int)sum);
 }
 
 uint32_t
 sctp_cksum(mblk_t *mp, int offset)
 {

@@ -466,21 +318,21 @@
                 }
 
                 if (mlen == 0)
                         continue;
 
-                if (is_even(w)) {
+                if (IS_EVEN(w)) {
                         sum = ip_ocsum(w, mlen >> 1, sum);
                         w += mlen >> 1;
                         /*
                          * If we had an odd number of bytes,
                          * then the last byte goes in the high
                          * part of the sum, and we take the
                          * first byte to the low part of the sum
                          * the next time around the loop.
                          */
-                        if (is_odd(mlen)) {
+                        if (IS_ODD(mlen)) {
 #ifdef _LITTLE_ENDIAN
                                 sum += *(uchar_t *)w;
 #else
                                 sum += *(uchar_t *)w << 8;
 #endif

@@ -506,11 +358,11 @@
                          * part of the sum.  Otherwise we had an
                          * odd number of bytes and we take the first
                          * byte to the low part of the sum the
                          * next time around the loop.
                          */
-                        if (is_odd(mlen)) {
+                        if (IS_ODD(mlen)) {
 #ifdef _LITTLE_ENDIAN
                                 sum += *(uchar_t *)w << 8;
 #else
                                 sum += *(uchar_t *)w;
 #endif