1 /*
   2  * Copyright (c) 2010, Oracle and/or its affiliates. All rights reserved.
   3  */
   4 
   5 /*
   6  * This file contains code imported from the OFED rds source file message.c
   7  * Oracle elects to have and use the contents of message.c under and governed
   8  * by the OpenIB.org BSD license (see below for full license text). However,
   9  * the following notice accompanied the original version of this file:
  10  */
  11 
  12 /*
  13  * Copyright (c) 2006 Oracle.  All rights reserved.
  14  *
  15  * This software is available to you under a choice of one of two
  16  * licenses.  You may choose to be licensed under the terms of the GNU
  17  * General Public License (GPL) Version 2, available from the file
  18  * COPYING in the main directory of this source tree, or the
  19  * OpenIB.org BSD license below:
  20  *
  21  *     Redistribution and use in source and binary forms, with or
  22  *     without modification, are permitted provided that the following
  23  *     conditions are met:
  24  *
  25  *      - Redistributions of source code must retain the above
  26  *        copyright notice, this list of conditions and the following
  27  *        disclaimer.
  28  *
  29  *      - Redistributions in binary form must reproduce the above
  30  *        copyright notice, this list of conditions and the following
  31  *        disclaimer in the documentation and/or other materials
  32  *        provided with the distribution.
  33  *
  34  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  35  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  36  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  37  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  38  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  39  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  40  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  41  * SOFTWARE.
  42  *
  43  */
  44 #include <sys/rds.h>
  45 
  46 #include <sys/ib/clients/rdsv3/rdsv3.h>
  47 #include <sys/ib/clients/rdsv3/rdma.h>
  48 #include <sys/ib/clients/rdsv3/rdsv3_debug.h>
  49 
  50 static unsigned int     rdsv3_exthdr_size[__RDSV3_EXTHDR_MAX] = {
  51 [RDSV3_EXTHDR_NONE]     = 0,
  52 [RDSV3_EXTHDR_VERSION]  = sizeof (struct rdsv3_ext_header_version),
  53 [RDSV3_EXTHDR_RDMA]     = sizeof (struct rdsv3_ext_header_rdma),
  54 [RDSV3_EXTHDR_RDMA_DEST]        = sizeof (struct rdsv3_ext_header_rdma_dest),
  55 };
  56 
  57 void
  58 rdsv3_message_addref(struct rdsv3_message *rm)
  59 {
  60         RDSV3_DPRINTF5("rdsv3_message_addref", "addref rm %p ref %d",
  61             rm, atomic_get(&rm->m_refcount));
  62         atomic_inc_32(&rm->m_refcount);
  63 }
  64 
  65 /*
  66  * This relies on dma_map_sg() not touching sg[].page during merging.
  67  */
  68 static void
  69 rdsv3_message_purge(struct rdsv3_message *rm)
  70 {
  71         unsigned long i;
  72 
  73         RDSV3_DPRINTF4("rdsv3_message_purge", "Enter(rm: %p)", rm);
  74 
  75         if (test_bit(RDSV3_MSG_PAGEVEC, &rm->m_flags))
  76                 return;
  77 
  78         for (i = 0; i < rm->m_nents; i++) {
  79                 RDSV3_DPRINTF5("rdsv3_message_purge", "putting data page %p\n",
  80                     (void *)rdsv3_sg_page(&rm->m_sg[i]));
  81                 /* XXX will have to put_page for page refs */
  82                 kmem_free(rdsv3_sg_page(&rm->m_sg[i]),
  83                     rdsv3_sg_len(&rm->m_sg[i]));
  84         }
  85 
  86         if (rm->m_rdma_op)
  87                 rdsv3_rdma_free_op(rm->m_rdma_op);
  88         if (rm->m_rdma_mr) {
  89                 struct rdsv3_mr *mr = rm->m_rdma_mr;
  90                 if (mr->r_refcount == 0) {
  91                         RDSV3_DPRINTF4("rdsv3_message_purge ASSERT 0",
  92                             "rm %p mr %p", rm, mr);
  93                         return;
  94                 }
  95                 if (mr->r_refcount == 0xdeadbeef) {
  96                         RDSV3_DPRINTF4("rdsv3_message_purge ASSERT deadbeef",
  97                             "rm %p mr %p", rm, mr);
  98                         return;
  99                 }
 100                 if (atomic_dec_and_test(&mr->r_refcount)) {
 101                         rm->m_rdma_mr = NULL;
 102                         __rdsv3_put_mr_final(mr);
 103                 }
 104         }
 105 
 106         RDSV3_DPRINTF4("rdsv3_message_purge", "Return(rm: %p)", rm);
 107 
 108 }
 109 
 110 void
 111 rdsv3_message_put(struct rdsv3_message *rm)
 112 {
 113         RDSV3_DPRINTF5("rdsv3_message_put",
 114             "put rm %p ref %d\n", rm, atomic_get(&rm->m_refcount));
 115 
 116         if (atomic_dec_and_test(&rm->m_refcount)) {
 117                 ASSERT(!list_link_active(&rm->m_sock_item));
 118                 ASSERT(!list_link_active(&rm->m_conn_item));
 119                 rdsv3_message_purge(rm);
 120 
 121                 kmem_free(rm, sizeof (struct rdsv3_message) +
 122                     (rm->m_nents * sizeof (struct rdsv3_scatterlist)));
 123         }
 124 }
 125 
 126 void
 127 rdsv3_message_inc_free(struct rdsv3_incoming *inc)
 128 {
 129         struct rdsv3_message *rm =
 130             container_of(inc, struct rdsv3_message, m_inc);
 131         rdsv3_message_put(rm);
 132 }
 133 
 134 void
 135 rdsv3_message_populate_header(struct rdsv3_header *hdr, uint16_be_t sport,
 136     uint16_be_t dport, uint64_t seq)
 137 {
 138         hdr->h_flags = 0;
 139         hdr->h_sport = sport;
 140         hdr->h_dport = dport;
 141         hdr->h_sequence = htonll(seq);
 142         hdr->h_exthdr[0] = RDSV3_EXTHDR_NONE;
 143 }
 144 
 145 int
 146 rdsv3_message_add_extension(struct rdsv3_header *hdr,
 147     unsigned int type, const void *data, unsigned int len)
 148 {
 149         unsigned int ext_len = sizeof (uint8_t) + len;
 150         unsigned char *dst;
 151 
 152         RDSV3_DPRINTF4("rdsv3_message_add_extension", "Enter");
 153 
 154         /* For now, refuse to add more than one extension header */
 155         if (hdr->h_exthdr[0] != RDSV3_EXTHDR_NONE)
 156                 return (0);
 157 
 158         if (type >= __RDSV3_EXTHDR_MAX ||
 159             len != rdsv3_exthdr_size[type])
 160                 return (0);
 161 
 162         if (ext_len >= RDSV3_HEADER_EXT_SPACE)
 163                 return (0);
 164         dst = hdr->h_exthdr;
 165 
 166         *dst++ = type;
 167         (void) memcpy(dst, data, len);
 168 
 169         dst[len] = RDSV3_EXTHDR_NONE;
 170 
 171         RDSV3_DPRINTF4("rdsv3_message_add_extension", "Return");
 172         return (1);
 173 }
 174 
 175 /*
 176  * If a message has extension headers, retrieve them here.
 177  * Call like this:
 178  *
 179  * unsigned int pos = 0;
 180  *
 181  * while (1) {
 182  *      buflen = sizeof(buffer);
 183  *      type = rdsv3_message_next_extension(hdr, &pos, buffer, &buflen);
 184  *      if (type == RDSV3_EXTHDR_NONE)
 185  *              break;
 186  *      ...
 187  * }
 188  */
 189 int
 190 rdsv3_message_next_extension(struct rdsv3_header *hdr,
 191     unsigned int *pos, void *buf, unsigned int *buflen)
 192 {
 193         unsigned int offset, ext_type, ext_len;
 194         uint8_t *src = hdr->h_exthdr;
 195 
 196         RDSV3_DPRINTF4("rdsv3_message_next_extension", "Enter");
 197 
 198         offset = *pos;
 199         if (offset >= RDSV3_HEADER_EXT_SPACE)
 200                 goto none;
 201 
 202         /*
 203          * Get the extension type and length. For now, the
 204          * length is implied by the extension type.
 205          */
 206         ext_type = src[offset++];
 207 
 208         if (ext_type == RDSV3_EXTHDR_NONE || ext_type >= __RDSV3_EXTHDR_MAX)
 209                 goto none;
 210         ext_len = rdsv3_exthdr_size[ext_type];
 211         if (offset + ext_len > RDSV3_HEADER_EXT_SPACE)
 212                 goto none;
 213 
 214         *pos = offset + ext_len;
 215         if (ext_len < *buflen)
 216                 *buflen = ext_len;
 217         (void) memcpy(buf, src + offset, *buflen);
 218         return (ext_type);
 219 
 220 none:
 221         *pos = RDSV3_HEADER_EXT_SPACE;
 222         *buflen = 0;
 223         return (RDSV3_EXTHDR_NONE);
 224 }
 225 
 226 int
 227 rdsv3_message_add_version_extension(struct rdsv3_header *hdr,
 228     unsigned int version)
 229 {
 230         struct rdsv3_ext_header_version ext_hdr;
 231 
 232         ext_hdr.h_version = htonl(version);
 233         return (rdsv3_message_add_extension(hdr, RDSV3_EXTHDR_VERSION,
 234             &ext_hdr, sizeof (ext_hdr)));
 235 }
 236 
 237 int
 238 rdsv3_message_get_version_extension(struct rdsv3_header *hdr,
 239     unsigned int *version)
 240 {
 241         struct rdsv3_ext_header_version ext_hdr;
 242         unsigned int pos = 0, len = sizeof (ext_hdr);
 243 
 244         RDSV3_DPRINTF4("rdsv3_message_get_version_extension", "Enter");
 245 
 246         /*
 247          * We assume the version extension is the only one present
 248          */
 249         if (rdsv3_message_next_extension(hdr, &pos, &ext_hdr, &len) !=
 250             RDSV3_EXTHDR_VERSION)
 251                 return (0);
 252         *version = ntohl(ext_hdr.h_version);
 253         return (1);
 254 }
 255 
 256 int
 257 rdsv3_message_add_rdma_dest_extension(struct rdsv3_header *hdr, uint32_t r_key,
 258     uint32_t offset)
 259 {
 260         struct rdsv3_ext_header_rdma_dest ext_hdr;
 261 
 262         ext_hdr.h_rdma_rkey = htonl(r_key);
 263         ext_hdr.h_rdma_offset = htonl(offset);
 264         return (rdsv3_message_add_extension(hdr, RDSV3_EXTHDR_RDMA_DEST,
 265             &ext_hdr, sizeof (ext_hdr)));
 266 }
 267 
 268 struct rdsv3_message *
 269 rdsv3_message_alloc(unsigned int nents, int gfp)
 270 {
 271         struct rdsv3_message *rm;
 272 
 273         RDSV3_DPRINTF4("rdsv3_message_alloc", "Enter(nents: %d)", nents);
 274 
 275         rm = kmem_zalloc(sizeof (struct rdsv3_message) +
 276             (nents * sizeof (struct rdsv3_scatterlist)), gfp);
 277         if (!rm)
 278                 goto out;
 279 
 280         rm->m_refcount = 1;
 281         list_link_init(&rm->m_sock_item);
 282         list_link_init(&rm->m_conn_item);
 283         mutex_init(&rm->m_rs_lock, NULL, MUTEX_DRIVER, NULL);
 284         rdsv3_init_waitqueue(&rm->m_flush_wait);
 285 
 286         RDSV3_DPRINTF4("rdsv3_message_alloc", "Return(rm: %p)", rm);
 287 out:
 288         return (rm);
 289 }
 290 
 291 struct rdsv3_message *
 292 rdsv3_message_map_pages(unsigned long *page_addrs, unsigned int total_len)
 293 {
 294         struct rdsv3_message *rm;
 295         unsigned int i;
 296 
 297         RDSV3_DPRINTF4("rdsv3_message_map_pages", "Enter(len: %d)", total_len);
 298 
 299         rm = rdsv3_message_alloc(ceil(total_len, PAGE_SIZE), KM_NOSLEEP);
 300         if (rm == NULL)
 301                 return (ERR_PTR(-ENOMEM));
 302 
 303         set_bit(RDSV3_MSG_PAGEVEC, &rm->m_flags);
 304         rm->m_inc.i_hdr.h_len = htonl(total_len);
 305         rm->m_nents = ceil(total_len, PAGE_SIZE);
 306         for (i = 0; i < rm->m_nents; ++i) {
 307                 rdsv3_sg_set_page(&rm->m_sg[i],
 308                     page_addrs[i],
 309                     PAGE_SIZE, 0);
 310         }
 311 
 312         return (rm);
 313 }
 314 
 315 struct rdsv3_message *
 316 rdsv3_message_copy_from_user(struct uio *uiop,
 317     size_t total_len)
 318 {
 319         struct rdsv3_message *rm;
 320         struct rdsv3_scatterlist *sg;
 321         int ret;
 322 
 323         RDSV3_DPRINTF4("rdsv3_message_copy_from_user", "Enter: %d", total_len);
 324 
 325         rm = rdsv3_message_alloc(ceil(total_len, PAGE_SIZE), KM_NOSLEEP);
 326         if (rm == NULL) {
 327                 ret = -ENOMEM;
 328                 goto out;
 329         }
 330 
 331         rm->m_inc.i_hdr.h_len = htonl(total_len);
 332 
 333         /*
 334          * now allocate and copy in the data payload.
 335          */
 336         sg = rm->m_sg;
 337 
 338         while (total_len) {
 339                 if (rdsv3_sg_page(sg) == NULL) {
 340                         ret = rdsv3_page_remainder_alloc(sg, total_len, 0);
 341                         if (ret)
 342                                 goto out;
 343                         rm->m_nents++;
 344                 }
 345 
 346                 ret = uiomove(rdsv3_sg_page(sg), rdsv3_sg_len(sg), UIO_WRITE,
 347                     uiop);
 348                 if (ret) {
 349                         RDSV3_DPRINTF2("rdsv3_message_copy_from_user",
 350                             "uiomove failed");
 351                         ret = -ret;
 352                         goto out;
 353                 }
 354 
 355                 total_len -= rdsv3_sg_len(sg);
 356                 sg++;
 357         }
 358         ret = 0;
 359 out:
 360         if (ret) {
 361                 if (rm)
 362                         rdsv3_message_put(rm);
 363                 rm = ERR_PTR(ret);
 364         }
 365         return (rm);
 366 }
 367 
 368 int
 369 rdsv3_message_inc_copy_to_user(struct rdsv3_incoming *inc,
 370     uio_t *uiop, size_t size)
 371 {
 372         struct rdsv3_message *rm;
 373         struct rdsv3_scatterlist *sg;
 374         unsigned long to_copy;
 375         unsigned long vec_off;
 376         int copied;
 377         int ret;
 378         uint32_t len;
 379 
 380         rm = container_of(inc, struct rdsv3_message, m_inc);
 381         len = ntohl(rm->m_inc.i_hdr.h_len);
 382 
 383         RDSV3_DPRINTF4("rdsv3_message_inc_copy_to_user",
 384             "Enter(rm: %p, len: %d)", rm, len);
 385 
 386         sg = rm->m_sg;
 387         vec_off = 0;
 388         copied = 0;
 389 
 390         while (copied < size && copied < len) {
 391 
 392                 to_copy = min(len - copied, sg->length - vec_off);
 393                 to_copy = min(size - copied, to_copy);
 394 
 395                 RDSV3_DPRINTF5("rdsv3_message_inc_copy_to_user",
 396                     "copying %lu bytes to user iov %p from sg [%p, %u] + %lu\n",
 397                     to_copy, uiop,
 398                     rdsv3_sg_page(sg), sg->length, vec_off);
 399 
 400                 ret = uiomove(rdsv3_sg_page(sg), to_copy, UIO_READ, uiop);
 401                 if (ret)
 402                         break;
 403 
 404                 vec_off += to_copy;
 405                 copied += to_copy;
 406 
 407                 if (vec_off == sg->length) {
 408                         vec_off = 0;
 409                         sg++;
 410                 }
 411         }
 412 
 413         return (copied);
 414 }
 415 
 416 /*
 417  * If the message is still on the send queue, wait until the transport
 418  * is done with it. This is particularly important for RDMA operations.
 419  */
 420 /* ARGSUSED */
 421 void
 422 rdsv3_message_wait(struct rdsv3_message *rm)
 423 {
 424         rdsv3_wait_event(&rm->m_flush_wait,
 425             !test_bit(RDSV3_MSG_MAPPED, &rm->m_flags));
 426 }
 427 
 428 void
 429 rdsv3_message_unmapped(struct rdsv3_message *rm)
 430 {
 431         clear_bit(RDSV3_MSG_MAPPED, &rm->m_flags);
 432         rdsv3_wake_up_all(&rm->m_flush_wait);
 433 }