1 /*
   2  * Copyright (C) 2010 Joseph Adams <joeyadams3.14159@gmail.com>
   3  *
   4  * Permission is hereby granted, free of charge, to any person obtaining a copy
   5  * of this software and associated documentation files (the "Software"), to deal
   6  * in the Software without restriction, including without limitation the rights
   7  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
   8  * copies of the Software, and to permit persons to whom the Software is
   9  * furnished to do so, subject to the following conditions:
  10  *
  11  * The above copyright notice and this permission notice shall be included in
  12  * all copies or substantial portions of the Software.
  13  *
  14  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  15  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  16  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  17  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  18  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  19  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  20  * THE SOFTWARE.
  21  */
  22 
  23 #include <assert.h>
  24 #include <stdlib.h>
  25 
  26 #include "smatch.h"
  27 #include "smatch_slist.h"
  28 
/* Allocation and release of individual tree nodes. */
static AvlNode *mkNode(const struct sm_state *sm);
static void freeNode(AvlNode *node);

/* Recursive search of a subtree, ordered by cmp_tracker(). */
static AvlNode *lookup(const struct stree *avl, AvlNode *node, const struct sm_state *sm);

/*
 * Recursive insertion/removal helpers.  Each returns whether the height of
 * the subtree rooted at *p changed (grew for insert_sm(), shrank for the
 * removal helpers).
 */
static bool insert_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm);
static bool remove_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm, AvlNode **ret);
static bool removeExtremum(AvlNode **p, int side, AvlNode **ret);

/* Rebalancing primitives. */
static int sway(AvlNode **p, int sway);
static void balance(AvlNode **p, int side);

/* Consistency checks used by avl_check_invariants(). */
static bool checkBalances(AvlNode *node, int *height);
static bool checkOrder(struct stree *avl);
static size_t countNode(AvlNode *node);

/*
 * Number of strees currently alive: incremented by avl_new() and
 * decremented by free_stree() when the last reference is dropped.
 * Presumably inspected elsewhere for leak accounting — not visible here.
 */
int unfree_stree;
  46 
  47 /*
  48  * Utility macros for converting between
  49  * "balance" values (-1 or 1) and "side" values (0 or 1).
  50  *
  51  * bal(0)   == -1
  52  * bal(1)   == +1
  53  * side(-1) == 0
  54  * side(+1) == 1
  55  */
  56 #define bal(side) ((side) == 0 ? -1 : 1)
  57 #define side(bal) ((bal)  == 1 ?  1 : 0)
  58 
  59 static struct stree *avl_new(void)
  60 {
  61         struct stree *avl = malloc(sizeof(*avl));
  62 
  63         unfree_stree++;
  64         assert(avl != NULL);
  65 
  66         avl->root = NULL;
  67         avl->base_stree = NULL;
  68         avl->has_states = calloc(num_checks + 1, sizeof(char));
  69         avl->count = 0;
  70         avl->stree_id = 0;
  71         avl->references = 1;
  72         return avl;
  73 }
  74 
  75 void free_stree(struct stree **avl)
  76 {
  77         if (!*avl)
  78                 return;
  79 
  80         assert((*avl)->references > 0);
  81 
  82         (*avl)->references--;
  83         if ((*avl)->references != 0) {
  84                 *avl = NULL;
  85                 return;
  86         }
  87 
  88         unfree_stree--;
  89 
  90         freeNode((*avl)->root);
  91         free(*avl);
  92         *avl = NULL;
  93 }
  94 
  95 struct sm_state *avl_lookup(const struct stree *avl, const struct sm_state *sm)
  96 {
  97         AvlNode *found;
  98 
  99         if (!avl)
 100                 return NULL;
 101         if (sm->owner != USHRT_MAX &&
 102             !avl->has_states[sm->owner])
 103                 return NULL;
 104         found = lookup(avl, avl->root, sm);
 105         if (!found)
 106                 return NULL;
 107         return (struct sm_state *)found->sm;
 108 }
 109 
 110 AvlNode *avl_lookup_node(const struct stree *avl, const struct sm_state *sm)
 111 {
 112         return lookup(avl, avl->root, sm);
 113 }
 114 
 115 size_t stree_count(const struct stree *avl)
 116 {
 117         if (!avl)
 118                 return 0;
 119         return avl->count;
 120 }
 121 
 122 static struct stree *clone_stree_real(struct stree *orig)
 123 {
 124         struct stree *new = avl_new();
 125         AvlIter i;
 126 
 127         avl_foreach(i, orig)
 128                 avl_insert(&new, i.sm);
 129 
 130         new->base_stree = orig->base_stree;
 131         return new;
 132 }
 133 
 134 bool avl_insert(struct stree **avl, const struct sm_state *sm)
 135 {
 136         size_t old_count;
 137 
 138         if (!*avl)
 139                 *avl = avl_new();
 140         if ((*avl)->references > 1) {
 141                 (*avl)->references--;
 142                 *avl = clone_stree_real(*avl);
 143         }
 144         old_count = (*avl)->count;
 145         /* fortunately we never call get_state() on "unnull_path" */
 146         if (sm->owner != USHRT_MAX)
 147                 (*avl)->has_states[sm->owner] = 1;
 148         insert_sm(*avl, &(*avl)->root, sm);
 149         return (*avl)->count != old_count;
 150 }
 151 
 152 bool avl_remove(struct stree **avl, const struct sm_state *sm)
 153 {
 154         AvlNode *node = NULL;
 155 
 156         if (!*avl)
 157                 return false;
 158         /* it's fairly rare for smatch to call avl_remove */
 159         if ((*avl)->references > 1) {
 160                 (*avl)->references--;
 161                 *avl = clone_stree_real(*avl);
 162         }
 163 
 164         remove_sm(*avl, &(*avl)->root, sm, &node);
 165 
 166         if ((*avl)->count == 0)
 167                 free_stree(avl);
 168 
 169         if (node == NULL) {
 170                 return false;
 171         } else {
 172                 free(node);
 173                 return true;
 174         }
 175 }
 176 
 177 static AvlNode *mkNode(const struct sm_state *sm)
 178 {
 179         AvlNode *node = malloc(sizeof(*node));
 180 
 181         assert(node != NULL);
 182 
 183         node->sm = sm;
 184         node->lr[0] = NULL;
 185         node->lr[1] = NULL;
 186         node->balance = 0;
 187         return node;
 188 }
 189 
 190 static void freeNode(AvlNode *node)
 191 {
 192         if (node) {
 193                 freeNode(node->lr[0]);
 194                 freeNode(node->lr[1]);
 195                 free(node);
 196         }
 197 }
 198 
 199 static AvlNode *lookup(const struct stree *avl, AvlNode *node, const struct sm_state *sm)
 200 {
 201         int cmp;
 202 
 203         if (node == NULL)
 204                 return NULL;
 205 
 206         cmp = cmp_tracker(sm, node->sm);
 207 
 208         if (cmp < 0)
 209                 return lookup(avl, node->lr[0], sm);
 210         if (cmp > 0)
 211                 return lookup(avl, node->lr[1], sm);
 212         return node;
 213 }
 214 
 215 /*
 216  * Insert an sm into a subtree, rebalancing if necessary.
 217  *
 218  * Return true if the subtree's height increased.
 219  */
 220 static bool insert_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm)
 221 {
 222         if (*p == NULL) {
 223                 *p = mkNode(sm);
 224                 avl->count++;
 225                 return true;
 226         } else {
 227                 AvlNode *node = *p;
 228                 int      cmp  = cmp_tracker(sm, node->sm);
 229 
 230                 if (cmp == 0) {
 231                         node->sm = sm;
 232                         return false;
 233                 }
 234 
 235                 if (!insert_sm(avl, &node->lr[side(cmp)], sm))
 236                         return false;
 237 
 238                 /* If tree's balance became -1 or 1, it means the tree's height grew due to insertion. */
 239                 return sway(p, cmp) != 0;
 240         }
 241 }
 242 
 243 /*
 244  * Remove the node matching the given sm.
 245  * If present, return the removed node through *ret .
 246  * The returned node's lr and balance are meaningless.
 247  *
 248  * Return true if the subtree's height decreased.
 249  */
 250 static bool remove_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm, AvlNode **ret)
 251 {
 252         if (p == NULL || *p == NULL) {
 253                 return false;
 254         } else {
 255                 AvlNode *node = *p;
 256                 int      cmp  = cmp_tracker(sm, node->sm);
 257 
 258                 if (cmp == 0) {
 259                         *ret = node;
 260                         avl->count--;
 261 
 262                         if (node->lr[0] != NULL && node->lr[1] != NULL) {
 263                                 AvlNode *replacement;
 264                                 int      side;
 265                                 bool     shrunk;
 266 
 267                                 /* Pick a subtree to pull the replacement from such that
 268                                  * this node doesn't have to be rebalanced. */
 269                                 side = node->balance <= 0 ? 0 : 1;
 270 
 271                                 shrunk = removeExtremum(&node->lr[side], 1 - side, &replacement);
 272 
 273                                 replacement->lr[0]   = node->lr[0];
 274                                 replacement->lr[1]   = node->lr[1];
 275                                 replacement->balance = node->balance;
 276                                 *p = replacement;
 277 
 278                                 if (!shrunk)
 279                                         return false;
 280 
 281                                 replacement->balance -= bal(side);
 282 
 283                                 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
 284                                 return replacement->balance == 0;
 285                         }
 286 
 287                         if (node->lr[0] != NULL)
 288                                 *p = node->lr[0];
 289                         else
 290                                 *p = node->lr[1];
 291 
 292                         return true;
 293 
 294                 } else {
 295                         if (!remove_sm(avl, &node->lr[side(cmp)], sm, ret))
 296                                 return false;
 297 
 298                         /* If tree's balance became 0, it means the tree's height shrank due to removal. */
 299                         return sway(p, -cmp) == 0;
 300                 }
 301         }
 302 }
 303 
 304 /*
 305  * Remove either the left-most (if side == 0) or right-most (if side == 1)
 306  * node in a subtree, returning the removed node through *ret .
 307  * The returned node's lr and balance are meaningless.
 308  *
 309  * The subtree must not be empty (i.e. *p must not be NULL).
 310  *
 311  * Return true if the subtree's height decreased.
 312  */
 313 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret)
 314 {
 315         AvlNode *node = *p;
 316 
 317         if (node->lr[side] == NULL) {
 318                 *ret = node;
 319                 *p = node->lr[1 - side];
 320                 return true;
 321         }
 322 
 323         if (!removeExtremum(&node->lr[side], side, ret))
 324                 return false;
 325 
 326         /* If tree's balance became 0, it means the tree's height shrank due to removal. */
 327         return sway(p, -bal(side)) == 0;
 328 }
 329 
 330 /*
 331  * Rebalance a node if necessary.  Think of this function
 332  * as a higher-level interface to balance().
 333  *
 334  * sway must be either -1 or 1, and indicates what was added to
 335  * the balance of this node by a prior operation.
 336  *
 337  * Return the new balance of the subtree.
 338  */
 339 static int sway(AvlNode **p, int sway)
 340 {
 341         if ((*p)->balance != sway)
 342                 (*p)->balance += sway;
 343         else
 344                 balance(p, side(sway));
 345 
 346         return (*p)->balance;
 347 }
 348 
 349 /*
 350  * Perform tree rotations on an unbalanced node.
 351  *
 352  * side == 0 means the node's balance is -2 .
 353  * side == 1 means the node's balance is +2 .
 354  */
 355 static void balance(AvlNode **p, int side)
 356 {
 357         AvlNode  *node  = *p,
 358                  *child = node->lr[side];
 359         int opposite    = 1 - side;
 360         int bal         = bal(side);
 361 
 362         if (child->balance != -bal) {
 363                 /* Left-left (side == 0) or right-right (side == 1) */
 364                 node->lr[side]      = child->lr[opposite];
 365                 child->lr[opposite] = node;
 366                 *p = child;
 367 
 368                 child->balance -= bal;
 369                 node->balance = -child->balance;
 370 
 371         } else {
 372                 /* Left-right (side == 0) or right-left (side == 1) */
 373                 AvlNode *grandchild = child->lr[opposite];
 374 
 375                 node->lr[side]           = grandchild->lr[opposite];
 376                 child->lr[opposite]      = grandchild->lr[side];
 377                 grandchild->lr[side]     = child;
 378                 grandchild->lr[opposite] = node;
 379                 *p = grandchild;
 380 
 381                 node->balance       = 0;
 382                 child->balance      = 0;
 383 
 384                 if (grandchild->balance == bal)
 385                         node->balance  = -bal;
 386                 else if (grandchild->balance == -bal)
 387                         child->balance = bal;
 388 
 389                 grandchild->balance = 0;
 390         }
 391 }
 392 
 393 
 394 /************************* avl_check_invariants() *************************/
 395 
 396 bool avl_check_invariants(struct stree *avl)
 397 {
 398         int    dummy;
 399 
 400         return checkBalances(avl->root, &dummy)
 401             && checkOrder(avl)
 402             && countNode(avl->root) == avl->count;
 403 }
 404 
 405 static bool checkBalances(AvlNode *node, int *height)
 406 {
 407         if (node) {
 408                 int h0, h1;
 409 
 410                 if (!checkBalances(node->lr[0], &h0))
 411                         return false;
 412                 if (!checkBalances(node->lr[1], &h1))
 413                         return false;
 414 
 415                 if (node->balance != h1 - h0 || node->balance < -1 || node->balance > 1)
 416                         return false;
 417 
 418                 *height = (h0 > h1 ? h0 : h1) + 1;
 419                 return true;
 420         } else {
 421                 *height = 0;
 422                 return true;
 423         }
 424 }
 425 
 426 static bool checkOrder(struct stree *avl)
 427 {
 428         AvlIter     i;
 429         const struct sm_state *last = NULL;
 430         bool        last_set = false;
 431 
 432         avl_foreach(i, avl) {
 433                 if (last_set && cmp_tracker(last, i.sm) >= 0)
 434                         return false;
 435                 last     = i.sm;
 436                 last_set = true;
 437         }
 438 
 439         return true;
 440 }
 441 
 442 static size_t countNode(AvlNode *node)
 443 {
 444         if (node)
 445                 return 1 + countNode(node->lr[0]) + countNode(node->lr[1]);
 446         else
 447                 return 0;
 448 }
 449 
 450 
 451 /************************* Traversal *************************/
 452 
 453 void avl_iter_begin(AvlIter *iter, struct stree *avl, AvlDirection dir)
 454 {
 455         AvlNode *node;
 456 
 457         iter->stack_index = 0;
 458         iter->direction   = dir;
 459 
 460         if (!avl || !avl->root) {
 461                 iter->sm      = NULL;
 462                 iter->node     = NULL;
 463                 return;
 464         }
 465         node = avl->root;
 466 
 467         while (node->lr[dir] != NULL) {
 468                 iter->stack[iter->stack_index++] = node;
 469                 node = node->lr[dir];
 470         }
 471 
 472         iter->sm   = (struct sm_state *) node->sm;
 473         iter->node  = node;
 474 }
 475 
 476 void avl_iter_next(AvlIter *iter)
 477 {
 478         AvlNode     *node = iter->node;
 479         AvlDirection dir  = iter->direction;
 480 
 481         if (node == NULL)
 482                 return;
 483 
 484         node = node->lr[1 - dir];
 485         if (node != NULL) {
 486                 while (node->lr[dir] != NULL) {
 487                         iter->stack[iter->stack_index++] = node;
 488                         node = node->lr[dir];
 489                 }
 490         } else if (iter->stack_index > 0) {
 491                 node = iter->stack[--iter->stack_index];
 492         } else {
 493                 iter->sm      = NULL;
 494                 iter->node     = NULL;
 495                 return;
 496         }
 497 
 498         iter->node  = node;
 499         iter->sm   = (struct sm_state *) node->sm;
 500 }
 501 
 502 struct stree *clone_stree(struct stree *orig)
 503 {
 504         if (!orig)
 505                 return NULL;
 506 
 507         orig->references++;
 508         return orig;
 509 }
 510 
 511 void set_stree_id(struct stree **stree, int stree_id)
 512 {
 513         if ((*stree)->stree_id != 0)
 514                 *stree = clone_stree_real(*stree);
 515 
 516         (*stree)->stree_id = stree_id;
 517 }
 518 
 519 int get_stree_id(struct stree *stree)
 520 {
 521         if (!stree)
 522                 return -1;
 523         return stree->stree_id;
 524 }