1 /*
   2  * Copyright (C) 2017 Oracle.
   3  *
   4  * This program is free software; you can redistribute it and/or
   5  * modify it under the terms of the GNU General Public License
   6  * as published by the Free Software Foundation; either version 2
   7  * of the License, or (at your option) any later version.
   8  *
   9  * This program is distributed in the hope that it will be useful,
  10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
  11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  12  * GNU General Public License for more details.
  13  *
  14  * You should have received a copy of the GNU General Public License
  15  * along with this program; if not, see http://www.gnu.org/copyleft/gpl.txt
  16  */
  17 
  18 #include "smatch.h"
  19 #include "smatch_slist.h"
  20 #include "smatch_extra.h"
  21 
  22 static int my_id;
  23 static int param_id;
  24 
  25 static struct stree *used_stree;
  26 static struct stree_stack *saved_stack;
  27 
  28 STATE(used);
  29 
  30 int get_param_from_container_of(struct expression *expr)
  31 {
  32         struct expression *param_expr;
  33         struct symbol *type;
  34         sval_t sval;
  35         int param;
  36 
  37 
  38         type = get_type(expr);
  39         if (!type || type->type != SYM_PTR)
  40                 return -1;
  41 
  42         expr = strip_expr(expr);
  43         if (expr->type != EXPR_BINOP || expr->op != '-')
  44                 return -1;
  45 
  46         if (!get_value(expr->right, &sval))
  47                 return -1;
  48         if (sval.value < 0 || sval.value > 4096)
  49                 return -1;
  50 
  51         param_expr = get_assigned_expr(expr->left);
  52         if (!param_expr)
  53                 return -1;
  54         param = get_param_num(param_expr);
  55         if (param < 0)
  56                 return -1;
  57 
  58         return param;
  59 }
  60 
  61 int get_offset_from_container_of(struct expression *expr)
  62 {
  63         struct expression *param_expr;
  64         struct symbol *type;
  65         sval_t sval;
  66 
  67         type = get_type(expr);
  68         if (!type || type->type != SYM_PTR)
  69                 return -1;
  70 
  71         expr = strip_expr(expr);
  72         if (expr->type != EXPR_BINOP || expr->op != '-')
  73                 return -1;
  74 
  75         if (!get_value(expr->right, &sval))
  76                 return -1;
  77         if (sval.value < 0 || sval.value > 4096)
  78                 return -1;
  79 
  80         param_expr = get_assigned_expr(expr->left);
  81         if (!param_expr)
  82                 return -1;
  83 
  84         return sval.value;
  85 }
  86 
  87 static int get_container_arg(struct symbol *sym)
  88 {
  89         struct expression *__mptr;
  90         int param;
  91 
  92         if (!sym || !sym->ident)
  93                 return -1;
  94 
  95         __mptr = get_assigned_expr_name_sym(sym->ident->name, sym);
  96         param = get_param_from_container_of(__mptr);
  97 
  98         return param;
  99 }
 100 
 101 static int get_container_offset(struct symbol *sym)
 102 {
 103         struct expression *__mptr;
 104         int offset;
 105 
 106         if (!sym || !sym->ident)
 107                 return -1;
 108 
 109         __mptr = get_assigned_expr_name_sym(sym->ident->name, sym);
 110         offset = get_offset_from_container_of(__mptr);
 111 
 112         return offset;
 113 }
 114 
 115 static char *get_container_name_sm(struct sm_state *sm, int offset)
 116 {
 117         static char buf[256];
 118         const char *name;
 119 
 120         name = get_param_name(sm);
 121         if (!name)
 122                 return NULL;
 123 
 124         if (name[0] == '$')
 125                 snprintf(buf, sizeof(buf), "$(-%d)%s", offset, name + 1);
 126         else if (name[0] == '*' || name[1] == '$')
 127                 snprintf(buf, sizeof(buf), "*$(-%d)%s", offset, name + 2);
 128         else
 129                 return NULL;
 130 
 131         return buf;
 132 }
 133 
 134 static void get_state_hook(int owner, const char *name, struct symbol *sym)
 135 {
 136         int arg;
 137 
 138         if (!option_info)
 139                 return;
 140         if (__in_fake_assign)
 141                 return;
 142 
 143         arg = get_container_arg(sym);
 144         if (arg >= 0)
 145                 set_state_stree(&used_stree, my_id, name, sym, &used);
 146 }
 147 
 148 static void set_param_used(struct expression *call, struct expression *arg, char *key, char *unused)
 149 {
 150         struct symbol *sym;
 151         char *name;
 152         int arg_nr;
 153 
 154         name = get_variable_from_key(arg, key, &sym);
 155         if (!name || !sym)
 156                 goto free;
 157 
 158         arg_nr = get_container_arg(sym);
 159         if (arg_nr >= 0)
 160                 set_state(my_id, name, sym, &used);
 161 free:
 162         free_string(name);
 163 }
 164 
 165 static void process_states(void)
 166 {
 167         struct sm_state *tmp;
 168         int arg, offset;
 169         const char *name;
 170 
 171         FOR_EACH_SM(used_stree, tmp) {
 172                 arg = get_container_arg(tmp->sym);
 173                 offset = get_container_offset(tmp->sym);
 174                 if (arg < 0 || offset < 0)
 175                         continue;
 176                 name = get_container_name_sm(tmp, offset);
 177                 if (!name)
 178                         continue;
 179                 sql_insert_return_implies(CONTAINER, arg, name, "");
 180         } END_FOR_EACH_SM(tmp);
 181 
 182         free_stree(&used_stree);
 183 }
 184 
 185 static void match_function_def(struct symbol *sym)
 186 {
 187         free_stree(&used_stree);
 188 }
 189 
 190 static void match_save_states(struct expression *expr)
 191 {
 192         push_stree(&saved_stack, used_stree);
 193         used_stree = NULL;
 194 }
 195 
 196 static void match_restore_states(struct expression *expr)
 197 {
 198         free_stree(&used_stree);
 199         used_stree = pop_stree(&saved_stack);
 200 }
 201 
 202 static void print_returns_container_of(int return_id, char *return_ranges, struct expression *expr)
 203 {
 204         int offset;
 205         int param;
 206         char key[64];
 207         char value[64];
 208 
 209         param = get_param_from_container_of(expr);
 210         if (param < 0)
 211                 return;
 212         offset = get_offset_from_container_of(expr);
 213         if (offset < 0)
 214                 return;
 215 
 216         snprintf(key, sizeof(key), "%d", param);
 217         snprintf(value, sizeof(value), "-%d", offset);
 218 
 219         /* no need to add it to return_implies because it's not really param_used */
 220         sql_insert_return_states(return_id, return_ranges, CONTAINER, -1,
 221                         key, value);
 222 }
 223 
 224 static void returns_container_of(struct expression *expr, int param, char *key, char *value)
 225 {
 226         struct expression *call, *arg;
 227         int offset;
 228         char buf[64];
 229 
 230         if (expr->type != EXPR_ASSIGNMENT || expr->op != '=')
 231                 return;
 232         call = strip_expr(expr->right);
 233         if (call->type != EXPR_CALL)
 234                 return;
 235         if (param != -1)
 236                 return;
 237         param = atoi(key);
 238         offset = atoi(value);
 239 
 240         arg = get_argument_from_call_expr(call->args, param);
 241         if (!arg)
 242                 return;
 243         if (arg->type != EXPR_SYMBOL)
 244                 return;
 245         param = get_param_num(arg);
 246         if (param < 0)
 247                 return;
 248         snprintf(buf, sizeof(buf), "$(%d)", offset);
 249         sql_insert_return_implies(CONTAINER, param, buf, "");
 250 }
 251 
 252 static int get_deref_count(struct expression *expr)
 253 {
 254         int cnt = 0;
 255 
 256         while (expr && expr->type == EXPR_DEREF) {
 257                 expr = expr->deref;
 258                 if (expr->type == EXPR_PREOP && expr->op == '*')
 259                         expr = expr->unop;
 260                 cnt++;
 261                 if (cnt > 100)
 262                         return -1;
 263         }
 264         return cnt;
 265 }
 266 
 267 static struct expression *get_partial_deref(struct expression *expr, int cnt)
 268 {
 269         while (--cnt >= 0) {
 270                 if (!expr || expr->type != EXPR_DEREF)
 271                         return expr;
 272                 expr = expr->deref;
 273                 if (expr->type == EXPR_PREOP && expr->op == '*')
 274                         expr = expr->unop;
 275         }
 276         return expr;
 277 }
 278 
 279 static int partial_deref_to_offset_str(struct expression *expr, int cnt, char op, char *buf, int size)
 280 {
 281         int n, offset;
 282 
 283         if (cnt == 0)
 284                 return snprintf(buf, size, "%c0", op);
 285 
 286         n = 0;
 287         while (--cnt >= 0) {
 288                 offset = get_member_offset_from_deref(expr);
 289                 if (offset < 0)
 290                         return -1;
 291                 n += snprintf(buf + n, size - n, "%c%d", op, offset);
 292                 if (expr->type != EXPR_DEREF)
 293                         return -1;
 294                 expr = expr->deref;
 295                 if (expr->type == EXPR_PREOP && expr->op == '*')
 296                         expr = expr->unop;
 297         }
 298 
 299         return n;
 300 }
 301 
 302 static char *get_shared_str(struct expression *container, struct expression *expr)
 303 {
 304         struct expression *one, *two;
 305         int cont, exp, min, ret, n;
 306         static char buf[48];
 307 
 308         cont = get_deref_count(container);
 309         exp = get_deref_count(expr);
 310         if (cont < 0 || exp < 0)
 311                 return NULL;
 312 
 313         min = (cont < exp) ? cont : exp;
 314         while (min >= 0) {
 315                 one = get_partial_deref(container, cont - min);
 316                 two = get_partial_deref(expr, exp - min);
 317                 if (expr_equiv(one, two))
 318                         goto found;
 319                 min--;
 320         }
 321 
 322         return NULL;
 323 
 324 found:
 325         ret = partial_deref_to_offset_str(container, cont - min, '-', buf, sizeof(buf));
 326         if (ret < 0)
 327                 return NULL;
 328         n = ret;
 329         ret = partial_deref_to_offset_str(expr, exp - min, '+', buf + ret, sizeof(buf) - ret);
 330         if (ret < 0)
 331                 return NULL;
 332         n += ret;
 333         if (n >= sizeof(buf))
 334                 return NULL;
 335 
 336         return buf;
 337 }
 338 
 339 char *get_container_name(struct expression *container, struct expression *expr)
 340 {
 341         struct symbol *container_sym, *sym;
 342         struct expression *tmp;
 343         static char buf[64];
 344         char *shared;
 345         bool star;
 346         int cnt;
 347 
 348         container_sym = expr_to_sym(container);
 349         sym = expr_to_sym(expr);
 350         if (container_sym && container_sym == sym)
 351                 goto found;
 352 
 353         cnt = 0;
 354         while ((tmp = get_assigned_expr(expr))) {
 355                 expr = tmp;
 356                 if (cnt++ > 3)
 357                         break;
 358         }
 359 
 360         cnt = 0;
 361         while ((tmp = get_assigned_expr(container))) {
 362                 container = tmp;
 363                 if (cnt++ > 3)
 364                         break;
 365         }
 366 
 367 found:
 368         expr = strip_expr(expr);
 369         star = true;
 370         if (expr->type == EXPR_PREOP && expr->op == '&') {
 371                 expr = strip_expr(expr->unop);
 372                 star = false;
 373         }
 374 
 375         container_sym = expr_to_sym(container);
 376         if (!container_sym)
 377                 return NULL;
 378         sym = expr_to_sym(expr);
 379         if (!sym || container_sym != sym)
 380                 return NULL;
 381 
 382         shared = get_shared_str(container, expr);
 383         if (star)
 384                 snprintf(buf, sizeof(buf), "*(%s)", shared);
 385         else
 386                 snprintf(buf, sizeof(buf), "%s", shared);
 387 
 388         return buf;
 389 }
 390 
 391 static void match_call(struct expression *call)
 392 {
 393         struct expression *fn, *arg;
 394         char *name;
 395         int param;
 396 
 397         /*
 398          * We're trying to link the function with the parameter.  There are a
 399          * couple ways this can be passed:
 400          * foo->func(foo, ...);
 401          * foo->func(foo->x, ...);
 402          * foo->bar.func(&foo->bar, ...);
 403          * foo->bar->baz->func(foo, ...);
 404          *
 405          * So the method is basically to subtract the offsets until we get to
 406          * the common bit, then add the member offsets to get the parameter.
 407          *
 408          * If we're taking an address then the offset math is not stared,
 409          * otherwise it is.  Starred means dereferenced.
 410          */
 411         fn = strip_expr(call->fn);
 412 
 413         param = -1;
 414         FOR_EACH_PTR(call->args, arg) {
 415                 param++;
 416 
 417                 name = get_container_name(fn, arg);
 418                 if (!name)
 419                         continue;
 420 
 421                 sql_insert_caller_info(call, CONTAINER, param, name, "$(-1)");
 422         } END_FOR_EACH_PTR(arg);
 423 }
 424 
 425 static void db_passed_container(const char *name, struct symbol *sym, char *key, char *value)
 426 {
 427         set_state(param_id, name, sym, alloc_state_str(key));
 428 }
 429 
 430 struct db_info {
 431         struct symbol *arg;
 432         int prev_offset;
 433         struct range_list *rl;
 434         int star;
 435         struct stree *stree;
 436 };
 437 
 438 static struct symbol *get_member_from_offset(struct symbol *sym, int offset)
 439 {
 440         struct symbol *type, *tmp;
 441         int cur;
 442 
 443         type = get_real_base_type(sym);
 444         if (!type || type->type != SYM_PTR)
 445                 return NULL;
 446         type = get_real_base_type(type);
 447         if (!type || type->type != SYM_STRUCT)
 448                 return NULL;
 449 
 450         cur = 0;
 451         FOR_EACH_PTR(type->symbol_list, tmp) {
 452                 cur = ALIGN(cur, tmp->ctype.alignment);
 453                 if (offset == cur)
 454                         return tmp;
 455                 cur += type_bytes(tmp);
 456         } END_FOR_EACH_PTR(tmp);
 457         return NULL;
 458 }
 459 
 460 static struct symbol *get_member_type_from_offset(struct symbol *sym, int offset)
 461 {
 462         struct symbol *base_type;
 463         struct symbol *member;
 464 
 465         base_type = get_real_base_type(sym);
 466         if (base_type && base_type->type == SYM_PTR)
 467                 base_type = get_real_base_type(base_type);
 468         if (offset == 0 && base_type && base_type->type == SYM_BASETYPE)
 469                 return base_type;
 470 
 471         member = get_member_from_offset(sym, offset);
 472         if (!member)
 473                 return NULL;
 474         return get_real_base_type(member);
 475 }
 476 
 477 static const char *get_name_from_offset(struct symbol *arg, int offset)
 478 {
 479         struct symbol *member, *type;
 480         const char *name;
 481         static char fullname[256];
 482 
 483         name = arg->ident->name;
 484 
 485         type = get_real_base_type(arg);
 486         if (!type || type->type != SYM_PTR)
 487                 return name;
 488 
 489         type = get_real_base_type(type);
 490         if (!type)
 491                 return NULL;
 492         if (type->type != SYM_STRUCT) {
 493                 snprintf(fullname, sizeof(fullname), "*%s", name);
 494                 return fullname;
 495         }
 496 
 497         member = get_member_from_offset(arg, offset);
 498         if (!member)
 499                 return NULL;
 500 
 501         snprintf(fullname, sizeof(fullname), "%s->%s", name, member->ident->name);
 502         return fullname;
 503 }
 504 
 505 static void set_param_value(struct stree **stree, struct symbol *arg, int offset, struct range_list *rl)
 506 {
 507         const char *name;
 508 
 509         name = get_name_from_offset(arg, offset);
 510         if (!name)
 511                 return;
 512         set_state_stree(stree, SMATCH_EXTRA, name, arg, alloc_estate_rl(rl));
 513 }
 514 
 515 static int save_vals(void *_db_info, int argc, char **argv, char **azColName)
 516 {
 517         struct db_info *db_info = _db_info;
 518         struct symbol *type;
 519         struct range_list *rl;
 520         int offset = 0;
 521         const char *value;
 522 
 523         if (argc == 2) {
 524                 offset = atoi(argv[0]);
 525                 value = argv[1];
 526         } else {
 527                 value = argv[0];
 528         }
 529 
 530         if (db_info->prev_offset != -1 &&
 531             db_info->prev_offset != offset) {
 532                 set_param_value(&db_info->stree, db_info->arg, db_info->prev_offset, db_info->rl);
 533                 db_info->rl = NULL;
 534         }
 535 
 536         db_info->prev_offset = offset;
 537 
 538         type = get_real_base_type(db_info->arg);
 539         if (db_info->star)
 540                 goto found_type;
 541         if (type->type != SYM_PTR)
 542                 return 0;
 543         type = get_real_base_type(type);
 544         if (type->type == SYM_BASETYPE)
 545                 goto found_type;
 546         type = get_member_type_from_offset(db_info->arg, offset);
 547 found_type:
 548         str_to_rl(type, (char *)value, &rl);
 549         if (db_info->rl)
 550                 db_info->rl = rl_union(db_info->rl, rl);
 551         else
 552                 db_info->rl = rl;
 553 
 554         return 0;
 555 }
 556 
 557 static struct stree *load_tag_info_sym(mtag_t tag, struct symbol *arg, int arg_offset, int star)
 558 {
 559         struct db_info db_info = {
 560                 .arg = arg,
 561                 .prev_offset = -1,
 562                 .star = star,
 563         };
 564         struct symbol *type;
 565 
 566         if (!tag || !arg->ident)
 567                 return NULL;
 568 
 569         type = get_real_base_type(arg);
 570         if (!type)
 571                 return NULL;
 572         if (!star) {
 573                 if (type->type != SYM_PTR)
 574                         return NULL;
 575                 type = get_real_base_type(type);
 576                 if (!type)
 577                         return NULL;
 578         }
 579 
 580         if (star || type->type == SYM_BASETYPE) {
 581                 run_sql(save_vals, &db_info,
 582                         "select value from mtag_data where tag = %lld and offset = %d and type = %d;",
 583                         tag, arg_offset, DATA_VALUE);
 584         } else {  /* presumably the parameter is a struct pointer */
 585                 run_sql(save_vals, &db_info,
 586                         "select offset, value from mtag_data where tag = %lld and type = %d;",
 587                         tag, DATA_VALUE);
 588         }
 589 
 590         if (db_info.prev_offset != -1)
 591                 set_param_value(&db_info.stree, arg, db_info.prev_offset, db_info.rl);
 592 
 593         // FIXME: handle an offset correctly
 594         if (!star && !arg_offset) {
 595                 sval_t sval;
 596 
 597                 sval.type = get_real_base_type(arg);
 598                 sval.uvalue = tag;
 599                 set_state_stree(&db_info.stree, SMATCH_EXTRA, arg->ident->name, arg, alloc_estate_sval(sval));
 600         }
 601         return db_info.stree;
 602 }
 603 
 604 static void load_container_data(struct symbol *arg, const char *info)
 605 {
 606         mtag_t cur_tag, container_tag, arg_tag;
 607         int container_offset, arg_offset;
 608         char *p = (char *)info;
 609         struct sm_state *sm;
 610         struct stree *stree;
 611         bool star = 0;
 612 
 613         if (p[0] == '*') {
 614                 star = 1;
 615                 p += 2;
 616         }
 617 
 618         if (!get_toplevel_mtag(cur_func_sym, &cur_tag))
 619                 return;
 620 
 621         while (true) {
 622                 container_offset = strtoul(p, &p, 0);
 623                 if (local_debug)
 624                         sm_msg("%s: cur_tag = %llu container_offset = %d",
 625                                __func__, cur_tag, container_offset);
 626                 if (!mtag_map_select_container(cur_tag, container_offset, &container_tag))
 627                         return;
 628                 cur_tag = container_tag;
 629                 if (local_debug)
 630                         sm_msg("%s: container_tag = %llu p = '%s'",
 631                                __func__, container_tag, p);
 632                 if (!p)
 633                         return;
 634                 if (p[0] != '-')
 635                         break;
 636                 p++;
 637         }
 638 
 639         if (p[0] != '+')
 640                 return;
 641 
 642         p++;
 643         arg_offset = strtoul(p, &p, 0);
 644         if (p && *p && *p != ')')
 645                 return;
 646 
 647         if (!arg_offset || star) {
 648                 arg_tag = container_tag;
 649         } else {
 650                 if (!mtag_map_select_tag(container_tag, -arg_offset, &arg_tag))
 651                         return;
 652         }
 653 
 654         stree = load_tag_info_sym(arg_tag, arg, arg_offset, star);
 655         FOR_EACH_SM(stree, sm) {
 656                 set_state(sm->owner, sm->name, sm->sym, sm->state);
 657         } END_FOR_EACH_SM(sm);
 658         free_stree(&stree);
 659 }
 660 
 661 static void handle_passed_container(struct symbol *sym)
 662 {
 663         struct symbol *arg;
 664         struct smatch_state *state;
 665 
 666         FOR_EACH_PTR(cur_func_sym->ctype.base_type->arguments, arg) {
 667                 state = get_state(param_id, arg->ident->name, arg);
 668                 if (!state || state == &merged)
 669                         continue;
 670                 load_container_data(arg, state->name);
 671         } END_FOR_EACH_PTR(arg);
 672 }
 673 
 674 void register_container_of(int id)
 675 {
 676         my_id = id;
 677 
 678         add_hook(&match_function_def, FUNC_DEF_HOOK);
 679 
 680         add_get_state_hook(&get_state_hook);
 681 
 682         add_hook(&match_save_states, INLINE_FN_START);
 683         add_hook(&match_restore_states, INLINE_FN_END);
 684 
 685         select_return_implies_hook(CONTAINER, &set_param_used);
 686         all_return_states_hook(&process_states);
 687 
 688         add_split_return_callback(&print_returns_container_of);
 689         select_return_states_hook(CONTAINER, &returns_container_of);
 690 
 691         add_hook(&match_call, FUNCTION_CALL_HOOK);
 692 }
 693 
 694 void register_container_of2(int id)
 695 {
 696         param_id = id;
 697 
 698         set_dynamic_states(param_id);
 699         select_caller_info_hook(db_passed_container, CONTAINER);
 700         add_merge_hook(param_id, &merge_str_state);
 701         add_hook(&handle_passed_container, AFTER_DEF_HOOK);
 702 }
 703