1 /*
   2  * Copyright (C) 2017 Oracle.
   3  *
   4  * This program is free software; you can redistribute it and/or
   5  * modify it under the terms of the GNU General Public License
   6  * as published by the Free Software Foundation; either version 2
   7  * of the License, or (at your option) any later version.
   8  *
   9  * This program is distributed in the hope that it will be useful,
  10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
  11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  12  * GNU General Public License for more details.
  13  *
  14  * You should have received a copy of the GNU General Public License
  15  * along with this program; if not, see http://www.gnu.org/copyleft/gpl.txt
  16  */
  17 
  18 #include "smatch.h"
  19 #include "smatch_slist.h"
  20 #include "smatch_extra.h"
  21 
/* Check id handed to register_container_of(); stored there. */
static int my_id;
/*
 * Check id handed to register_container_of2().  Used for the per-parameter
 * "key|value" string states created from CONTAINER caller_info.
 */
static int param_id;
  24 
  25 int get_param_from_container_of(struct expression *expr)
  26 {
  27         struct expression *param_expr;
  28         struct symbol *type;
  29         sval_t sval;
  30         int param;
  31 
  32 
  33         type = get_type(expr);
  34         if (!type || type->type != SYM_PTR)
  35                 return -1;
  36 
  37         expr = strip_expr(expr);
  38         if (expr->type != EXPR_BINOP || expr->op != '-')
  39                 return -1;
  40 
  41         if (!get_value(expr->right, &sval))
  42                 return -1;
  43         if (sval.value < 0 || sval.value > 4096)
  44                 return -1;
  45 
  46         param_expr = get_assigned_expr(expr->left);
  47         if (!param_expr)
  48                 return -1;
  49         param = get_param_num(param_expr);
  50         if (param < 0)
  51                 return -1;
  52 
  53         return param;
  54 }
  55 
  56 int get_offset_from_container_of(struct expression *expr)
  57 {
  58         struct expression *param_expr;
  59         struct symbol *type;
  60         sval_t sval;
  61 
  62         type = get_type(expr);
  63         if (!type || type->type != SYM_PTR)
  64                 return -1;
  65 
  66         expr = strip_expr(expr);
  67         if (expr->type != EXPR_BINOP || expr->op != '-')
  68                 return -1;
  69 
  70         if (!get_value(expr->right, &sval))
  71                 return -1;
  72         if (sval.value < 0 || sval.value > 4096)
  73                 return -1;
  74 
  75         param_expr = get_assigned_expr(expr->left);
  76         if (!param_expr)
  77                 return -1;
  78 
  79         return sval.value;
  80 }
  81 
  82 static void print_returns_container_of(int return_id, char *return_ranges, struct expression *expr)
  83 {
  84         int offset;
  85         int param;
  86         char key[64];
  87         char value[64];
  88 
  89         param = get_param_from_container_of(expr);
  90         if (param < 0)
  91                 return;
  92         offset = get_offset_from_container_of(expr);
  93         if (offset < 0)
  94                 return;
  95 
  96         snprintf(key, sizeof(key), "%d", param);
  97         snprintf(value, sizeof(value), "-%d", offset);
  98 
  99         /* no need to add it to return_implies because it's not really param_used */
 100         sql_insert_return_states(return_id, return_ranges, CONTAINER, -1,
 101                         key, value);
 102 }
 103 
 104 static int get_deref_count(struct expression *expr)
 105 {
 106         int cnt = 0;
 107 
 108         while (expr && expr->type == EXPR_DEREF) {
 109                 expr = expr->deref;
 110                 if (expr->type == EXPR_PREOP && expr->op == '*')
 111                         expr = expr->unop;
 112                 cnt++;
 113                 if (cnt > 100)
 114                         return -1;
 115         }
 116         return cnt;
 117 }
 118 
 119 static struct expression *get_partial_deref(struct expression *expr, int cnt)
 120 {
 121         while (--cnt >= 0) {
 122                 if (!expr || expr->type != EXPR_DEREF)
 123                         return expr;
 124                 expr = expr->deref;
 125                 if (expr->type == EXPR_PREOP && expr->op == '*')
 126                         expr = expr->unop;
 127         }
 128         return expr;
 129 }
 130 
 131 static int partial_deref_to_offset_str(struct expression *expr, int cnt, char op, char *buf, int size)
 132 {
 133         int n, offset;
 134 
 135         if (cnt == 0)
 136                 return snprintf(buf, size, "%c0", op);
 137 
 138         n = 0;
 139         while (--cnt >= 0) {
 140                 offset = get_member_offset_from_deref(expr);
 141                 if (offset < 0)
 142                         return -1;
 143                 n += snprintf(buf + n, size - n, "%c%d", op, offset);
 144                 if (expr->type != EXPR_DEREF)
 145                         return -1;
 146                 expr = expr->deref;
 147                 if (expr->type == EXPR_PREOP && expr->op == '*')
 148                         expr = expr->unop;
 149         }
 150 
 151         return n;
 152 }
 153 
 154 static char *get_shared_str(struct expression *expr, struct expression *container)
 155 {
 156         struct expression *one, *two;
 157         int exp, cont, min, ret, n;
 158         static char buf[48];
 159 
 160         exp = get_deref_count(expr);
 161         cont = get_deref_count(container);
 162         if (exp < 0 || cont < 0)
 163                 return NULL;
 164 
 165         min = (exp < cont) ? exp : cont;
 166         while (min >= 0) {
 167                 one = get_partial_deref(expr, exp - min);
 168                 two = get_partial_deref(container, cont - min);
 169                 if (expr_equiv(one, two))
 170                         goto found;
 171                 min--;
 172         }
 173 
 174         return NULL;
 175 
 176 found:
 177         ret = partial_deref_to_offset_str(expr, exp - min, '-', buf, sizeof(buf));
 178         if (ret < 0)
 179                 return NULL;
 180         n = ret;
 181         ret = partial_deref_to_offset_str(container, cont - min, '+', buf + ret, sizeof(buf) - ret);
 182         if (ret < 0)
 183                 return NULL;
 184         n += ret;
 185         if (n >= sizeof(buf))
 186                 return NULL;
 187 
 188         return buf;
 189 }
 190 
 191 static char *get_stored_container_name(struct expression *container,
 192                                        struct expression *expr)
 193 {
 194         struct smatch_state *state;
 195         static char buf[64];
 196         char *p;
 197         int param;
 198 
 199         if (!container || container->type != EXPR_SYMBOL)
 200                 return NULL;
 201         if (!expr || expr->type != EXPR_SYMBOL)
 202                 return NULL;
 203         state = get_state_expr(param_id, expr);
 204         if (!state)
 205                 return NULL;
 206 
 207         snprintf(buf, sizeof(buf), "%s", state->name);
 208         p = strchr(buf, '|');
 209         if (!p)
 210                 return NULL;
 211         *p = '\0';
 212         param = atoi(p + 2);
 213         if (get_param_sym_from_num(param) == container->symbol)
 214                 return buf;
 215         return NULL;
 216 }
 217 
 218 char *get_container_name(struct expression *container, struct expression *expr)
 219 {
 220         struct symbol *container_sym, *sym;
 221         struct expression *tmp;
 222         static char buf[64];
 223         char *ret, *shared;
 224         bool star;
 225         int cnt;
 226 
 227         expr = strip_expr(expr);
 228         container = strip_expr(container);
 229 
 230         ret = get_stored_container_name(container, expr);
 231         if (ret)
 232                 return ret;
 233 
 234         sym = expr_to_sym(expr);
 235         container_sym = expr_to_sym(container);
 236         if (sym && sym == container_sym)
 237                 goto found;
 238 
 239         cnt = 0;
 240         while ((tmp = get_assigned_expr(container))) {
 241                 container = strip_expr(tmp);
 242                 if (cnt++ > 3)
 243                         break;
 244         }
 245 
 246         cnt = 0;
 247         while ((tmp = get_assigned_expr(expr))) {
 248                 expr = strip_expr(tmp);
 249                 if (cnt++ > 3)
 250                         break;
 251         }
 252 
 253 found:
 254 
 255         if (container->type == EXPR_DEREF)
 256                 star = true;
 257         else
 258                 star = false;
 259 
 260         if (container->type == EXPR_PREOP && container->op == '&')
 261                 container = strip_expr(container->unop);
 262         if (expr->type == EXPR_PREOP && expr->op == '&')
 263                 expr = strip_expr(expr->unop);
 264 
 265         sym = expr_to_sym(expr);
 266         if (!sym)
 267                 return NULL;
 268         container_sym = expr_to_sym(container);
 269         if (!container_sym || sym != container_sym)
 270                 return NULL;
 271 
 272         shared = get_shared_str(expr, container);
 273         if (!shared)
 274                 return NULL;
 275         if (star)
 276                 snprintf(buf, sizeof(buf), "*(%s)", shared);
 277         else
 278                 snprintf(buf, sizeof(buf), "%s", shared);
 279 
 280         return buf;
 281 }
 282 
 283 static bool is_fn_ptr(struct expression *expr)
 284 {
 285         struct symbol *type;
 286 
 287         if (!expr)
 288                 return false;
 289         if (expr->type != EXPR_SYMBOL && expr->type != EXPR_DEREF)
 290                 return false;
 291 
 292         type = get_type(expr);
 293         if (!type || type->type != SYM_PTR)
 294                 return false;
 295         type = get_real_base_type(type);
 296         if (!type || type->type != SYM_FN)
 297                 return false;
 298         return true;
 299 }
 300 
 301 static void match_call(struct expression *call)
 302 {
 303         struct expression *fn, *arg, *tmp;
 304         bool found = false;
 305         int fn_param, param;
 306         char buf[32];
 307         char *name;
 308 
 309         /*
 310          * We're trying to link the function with the parameter.  There are a
 311          * couple ways this can be passed:
 312          * foo->func(foo, ...);
 313          * foo->func(foo->x, ...);
 314          * foo->bar.func(&foo->bar, ...);
 315          * foo->bar->baz->func(foo, ...);
 316          *
 317          * So the method is basically to subtract the offsets until we get to
 318          * the common bit, then add the member offsets to get the parameter.
 319          *
 320          * If we're taking an address then the offset math is not stared,
 321          * otherwise it is.  Starred means dereferenced.
 322          */
 323         fn = strip_expr(call->fn);
 324 
 325         param = -1;
 326         FOR_EACH_PTR(call->args, arg) {
 327                 param++;
 328 
 329                 name = get_container_name(arg, fn);
 330                 if (!name)
 331                         continue;
 332 
 333                 found = true;
 334                 sql_insert_caller_info(call, CONTAINER, param, name, "$(-1)");
 335         } END_FOR_EACH_PTR(arg);
 336 
 337         if (found)
 338                 return;
 339 
 340         fn_param = -1;
 341         FOR_EACH_PTR(call->args, arg) {
 342                 fn_param++;
 343                 if (!is_fn_ptr(arg))
 344                         continue;
 345                 param = -1;
 346                 FOR_EACH_PTR(call->args, tmp) {
 347                         param++;
 348 
 349                         /* the function isn't it's own container */
 350                         if (arg == tmp)
 351                                 continue;
 352 
 353                         name = get_container_name(tmp, arg);
 354                         if (!name)
 355                                 continue;
 356 
 357                         snprintf(buf, sizeof(buf), "$%d", param);
 358                         sql_insert_caller_info(call, CONTAINER, fn_param, name, buf);
 359                         return;
 360                 } END_FOR_EACH_PTR(tmp);
 361         } END_FOR_EACH_PTR(arg);
 362 }
 363 
 364 static void db_passed_container(const char *name, struct symbol *sym, char *key, char *value)
 365 {
 366         char buf[64];
 367 
 368         snprintf(buf, sizeof(buf), "%s|%s", key, value);
 369         set_state(param_id, name, sym, alloc_state_str(buf));
 370 }
 371 
/* Accumulator passed to save_vals() while loading mtag_data rows for @arg. */
struct db_info {
	/* Parameter whose values are being loaded. */
	struct symbol *arg;
	/* Offset of the last row seen; -1 until the first row arrives. */
	int prev_offset;
	/* Union of the ranges seen so far for prev_offset. */
	struct range_list *rl;
	/* Nonzero: parse values with arg's own type instead of a pointee/member type. */
	int star;
	/* SMATCH_EXTRA states built from the completed offsets. */
	struct stree *stree;
};
 379 
 380 static struct symbol *get_member_from_offset(struct symbol *sym, int offset)
 381 {
 382         struct symbol *type, *tmp;
 383         int cur;
 384 
 385         type = get_real_base_type(sym);
 386         if (!type || type->type != SYM_PTR)
 387                 return NULL;
 388         type = get_real_base_type(type);
 389         if (!type || type->type != SYM_STRUCT)
 390                 return NULL;
 391 
 392         cur = 0;
 393         FOR_EACH_PTR(type->symbol_list, tmp) {
 394                 cur = ALIGN(cur, tmp->ctype.alignment);
 395                 if (offset == cur)
 396                         return tmp;
 397                 cur += type_bytes(tmp);
 398         } END_FOR_EACH_PTR(tmp);
 399         return NULL;
 400 }
 401 
 402 static struct symbol *get_member_type_from_offset(struct symbol *sym, int offset)
 403 {
 404         struct symbol *base_type;
 405         struct symbol *member;
 406 
 407         base_type = get_real_base_type(sym);
 408         if (base_type && base_type->type == SYM_PTR)
 409                 base_type = get_real_base_type(base_type);
 410         if (offset == 0 && base_type && base_type->type == SYM_BASETYPE)
 411                 return base_type;
 412 
 413         member = get_member_from_offset(sym, offset);
 414         if (!member)
 415                 return NULL;
 416         return get_real_base_type(member);
 417 }
 418 
 419 static const char *get_name_from_offset(struct symbol *arg, int offset)
 420 {
 421         struct symbol *member, *type;
 422         const char *name;
 423         static char fullname[256];
 424 
 425         name = arg->ident->name;
 426 
 427         type = get_real_base_type(arg);
 428         if (!type || type->type != SYM_PTR)
 429                 return name;
 430 
 431         type = get_real_base_type(type);
 432         if (!type)
 433                 return NULL;
 434         if (type->type != SYM_STRUCT) {
 435                 snprintf(fullname, sizeof(fullname), "*%s", name);
 436                 return fullname;
 437         }
 438 
 439         member = get_member_from_offset(arg, offset);
 440         if (!member || !member->ident)
 441                 return NULL;
 442 
 443         snprintf(fullname, sizeof(fullname), "%s->%s", name, member->ident->name);
 444         return fullname;
 445 }
 446 
 447 static void set_param_value(struct stree **stree, struct symbol *arg, int offset, struct range_list *rl)
 448 {
 449         const char *name;
 450 
 451         name = get_name_from_offset(arg, offset);
 452         if (!name)
 453                 return;
 454         set_state_stree(stree, SMATCH_EXTRA, name, arg, alloc_estate_rl(rl));
 455 }
 456 
 457 static int save_vals(void *_db_info, int argc, char **argv, char **azColName)
 458 {
 459         struct db_info *db_info = _db_info;
 460         struct symbol *type;
 461         struct range_list *rl;
 462         int offset = 0;
 463         const char *value;
 464 
 465         if (argc == 2) {
 466                 offset = atoi(argv[0]);
 467                 value = argv[1];
 468         } else {
 469                 value = argv[0];
 470         }
 471 
 472         if (db_info->prev_offset != -1 &&
 473             db_info->prev_offset != offset) {
 474                 set_param_value(&db_info->stree, db_info->arg, db_info->prev_offset, db_info->rl);
 475                 db_info->rl = NULL;
 476         }
 477 
 478         db_info->prev_offset = offset;
 479 
 480         type = get_real_base_type(db_info->arg);
 481         if (db_info->star)
 482                 goto found_type;
 483         if (type->type != SYM_PTR)
 484                 return 0;
 485         type = get_real_base_type(type);
 486         if (type->type == SYM_BASETYPE)
 487                 goto found_type;
 488         type = get_member_type_from_offset(db_info->arg, offset);
 489 found_type:
 490         str_to_rl(type, (char *)value, &rl);
 491         if (db_info->rl)
 492                 db_info->rl = rl_union(db_info->rl, rl);
 493         else
 494                 db_info->rl = rl;
 495 
 496         return 0;
 497 }
 498 
 499 static struct stree *load_tag_info_sym(mtag_t tag, struct symbol *arg, int arg_offset, int star)
 500 {
 501         struct db_info db_info = {
 502                 .arg = arg,
 503                 .prev_offset = -1,
 504                 .star = star,
 505         };
 506         struct symbol *type;
 507 
 508         if (!tag || !arg->ident)
 509                 return NULL;
 510 
 511         type = get_real_base_type(arg);
 512         if (!type)
 513                 return NULL;
 514         if (!star) {
 515                 if (type->type != SYM_PTR)
 516                         return NULL;
 517                 type = get_real_base_type(type);
 518                 if (!type)
 519                         return NULL;
 520         }
 521 
 522         if (star || type->type == SYM_BASETYPE) {
 523                 run_sql(save_vals, &db_info,
 524                         "select value from mtag_data where tag = %lld and offset = %d and type = %d;",
 525                         tag, arg_offset, DATA_VALUE);
 526         } else {  /* presumably the parameter is a struct pointer */
 527                 run_sql(save_vals, &db_info,
 528                         "select offset, value from mtag_data where tag = %lld and type = %d;",
 529                         tag, DATA_VALUE);
 530         }
 531 
 532         if (db_info.prev_offset != -1)
 533                 set_param_value(&db_info.stree, arg, db_info.prev_offset, db_info.rl);
 534 
 535         // FIXME: handle an offset correctly
 536         if (!star && !arg_offset) {
 537                 sval_t sval;
 538 
 539                 sval.type = get_real_base_type(arg);
 540                 sval.uvalue = tag;
 541                 set_state_stree(&db_info.stree, SMATCH_EXTRA, arg->ident->name, arg, alloc_estate_sval(sval));
 542         }
 543         return db_info.stree;
 544 }
 545 
/*
 * Follow a "$(-1)" CONTAINER state for @arg and load the values it leads to.
 *
 * @info is the "<offsets>|<value>" string stored by db_passed_container().
 * Only <value> == "$(-1)" is handled.  <offsets> is expected to be:
 *  - a run of '-' offsets followed by a '+' offset (the get_shared_str()
 *    format);
 *  - optionally wrapped as "*(...)".
 *
 * Walk:
 *  1. Start from the current function's top-level mtag.
 *  2. Each '-' offset steps to a containing mtag via
 *     mtag_map_select_container().
 *  3. The '+' offset then selects @arg's mtag inside that container.
 *  4. That mtag's data is loaded with load_tag_info_sym() and copied into
 *     the current state.
 */
static void load_container_data(struct symbol *arg, const char *info)
{
	mtag_t cur_tag, container_tag, arg_tag;
	int container_offset, arg_offset;
	struct sm_state *sm;
	struct stree *stree;
	char *p, *cont;
	char copy[64];
	bool star = 0;

	snprintf(copy, sizeof(copy), "%s", info);
	p = strchr(copy, '|');
	if (!p)
		return;
	*p = '\0';
	cont = p + 1;
	p = copy;
	if (p[0] == '*') {
		/* skip the "*(" prefix; the closing ')' is tolerated further down */
		star = 1;
		p += 2;
	}

	if (strcmp(cont, "$(-1)") != 0)
		return;

	if (!get_toplevel_mtag(cur_func_sym, &cur_tag))
		return;

	while (true) {
		/*
		 * NOTE(review): on the first pass p still starts with '-', so
		 * strtoul() consumes that sign and the stored int comes out
		 * negative.  Later passes have their '-' skipped by the p++
		 * below and come out positive.  Confirm which sign
		 * mtag_map_select_container() expects.
		 */
		container_offset = strtoul(p, &p, 0);
		if (local_debug)
			sm_msg("%s: cur_tag = %llu container_offset = %d",
			       __func__, cur_tag, container_offset);
		if (!mtag_map_select_container(cur_tag, container_offset, &container_tag))
			return;
		cur_tag = container_tag;
		if (local_debug)
			sm_msg("%s: container_tag = %llu p = '%s'",
			       __func__, container_tag, p);
		/* NOTE(review): strtoul() never sets its end pointer to NULL, so this looks dead. */
		if (!p)
			return;
		if (p[0] != '-')
			break;
		p++;
	}

	/* The container walk must end at the '+' offset that locates @arg. */
	if (p[0] != '+')
		return;

	p++;
	arg_offset = strtoul(p, &p, 0);
	/* Only end-of-string or the ')' of a "*(...)" wrapper may follow. */
	if (p && *p && *p != ')')
		return;

	if (!arg_offset || star) {
		arg_tag = container_tag;
	} else {
		if (!mtag_map_select_tag(container_tag, -arg_offset, &arg_tag))
			return;
	}

	stree = load_tag_info_sym(arg_tag, arg, arg_offset, star);
	FOR_EACH_SM(stree, sm) {
		set_state(sm->owner, sm->name, sm->sym, sm->state);
	} END_FOR_EACH_SM(sm);
	free_stree(&stree);
}
 613 
 614 static void handle_passed_container(struct symbol *sym)
 615 {
 616         struct symbol *arg;
 617         struct smatch_state *state;
 618 
 619         FOR_EACH_PTR(cur_func_sym->ctype.base_type->arguments, arg) {
 620                 state = get_state(param_id, arg->ident->name, arg);
 621                 if (!state || state == &merged)
 622                         continue;
 623                 load_container_data(arg, state->name);
 624         } END_FOR_EACH_PTR(arg);
 625 }
 626 
/*
 * Producer side of the check, registered under @id:
 *  - a split-return callback records container_of() style returns;
 *  - a function-call hook emits CONTAINER caller_info.
 */
void register_container_of(int id)
{
	my_id = id;

	add_split_return_callback(&print_returns_container_of);
	add_hook(&match_call, FUNCTION_CALL_HOOK);
}
 634 
/*
 * Consumer side of the check, registered under a separate @id (param_id):
 *  - CONTAINER caller_info rows become string states on parameters;
 *  - those states are merged as strings;
 *  - they are acted on from the AFTER_DEF_HOOK.
 */
void register_container_of2(int id)
{
	param_id = id;

	set_dynamic_states(param_id);
	select_caller_info_hook(db_passed_container, CONTAINER);
	add_merge_hook(param_id, &merge_str_state);
	add_hook(&handle_passed_container, AFTER_DEF_HOOK);
}
 644