1 /*
   2  * Copyright (C) 2017 Oracle.
   3  *
   4  * This program is free software; you can redistribute it and/or
   5  * modify it under the terms of the GNU General Public License
   6  * as published by the Free Software Foundation; either version 2
   7  * of the License, or (at your option) any later version.
   8  *
   9  * This program is distributed in the hope that it will be useful,
  10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
  11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  12  * GNU General Public License for more details.
  13  *
  14  * You should have received a copy of the GNU General Public License
  15  * along with this program; if not, see http://www.gnu.org/copyleft/gpl.txt
  16  */
  17 
  18 #include "smatch.h"
  19 #include "smatch_slist.h"
  20 #include "smatch_extra.h"
  21 
/* Id passed to register_container_of(); it is only stored in this file. */
static int my_id;
/*
 * Id passed to register_container_of2().  States owned by this id carry a
 * "key|value" string built by db_passed_container().
 */
static int param_id;
  24 
  25 int get_param_from_container_of(struct expression *expr)
  26 {
  27         struct expression *param_expr;
  28         struct symbol *type;
  29         sval_t sval;
  30         int param;
  31 
  32 
  33         type = get_type(expr);
  34         if (!type || type->type != SYM_PTR)
  35                 return -1;
  36 
  37         expr = strip_expr(expr);
  38         if (expr->type != EXPR_BINOP || expr->op != '-')
  39                 return -1;
  40 
  41         if (!get_value(expr->right, &sval))
  42                 return -1;
  43         if (sval.value < 0 || sval.value > 4096)
  44                 return -1;
  45 
  46         param_expr = get_assigned_expr(expr->left);
  47         if (!param_expr)
  48                 return -1;
  49         param = get_param_num(param_expr);
  50         if (param < 0)
  51                 return -1;
  52 
  53         return param;
  54 }
  55 
  56 int get_offset_from_container_of(struct expression *expr)
  57 {
  58         struct expression *param_expr;
  59         struct symbol *type;
  60         sval_t sval;
  61 
  62         type = get_type(expr);
  63         if (!type || type->type != SYM_PTR)
  64                 return -1;
  65 
  66         expr = strip_expr(expr);
  67         if (expr->type != EXPR_BINOP || expr->op != '-')
  68                 return -1;
  69 
  70         if (!get_value(expr->right, &sval))
  71                 return -1;
  72         if (sval.value < 0 || sval.value > 4096)
  73                 return -1;
  74 
  75         param_expr = get_assigned_expr(expr->left);
  76         if (!param_expr)
  77                 return -1;
  78 
  79         return sval.value;
  80 }
  81 
  82 static void print_returns_container_of(int return_id, char *return_ranges, struct expression *expr)
  83 {
  84         int offset;
  85         int param;
  86         char key[64];
  87         char value[64];
  88 
  89         param = get_param_from_container_of(expr);
  90         if (param < 0)
  91                 return;
  92         offset = get_offset_from_container_of(expr);
  93         if (offset < 0)
  94                 return;
  95 
  96         snprintf(key, sizeof(key), "%d", param);
  97         snprintf(value, sizeof(value), "-%d", offset);
  98 
  99         /* no need to add it to return_implies because it's not really param_used */
 100         sql_insert_return_states(return_id, return_ranges, CONTAINER, -1,
 101                         key, value);
 102 }
 103 
 104 static int get_deref_count(struct expression *expr)
 105 {
 106         int cnt = 0;
 107 
 108         while (expr && expr->type == EXPR_DEREF) {
 109                 expr = expr->deref;
 110                 if (expr->type == EXPR_PREOP && expr->op == '*')
 111                         expr = expr->unop;
 112                 cnt++;
 113                 if (cnt > 100)
 114                         return -1;
 115         }
 116         return cnt;
 117 }
 118 
 119 static struct expression *get_partial_deref(struct expression *expr, int cnt)
 120 {
 121         while (--cnt >= 0) {
 122                 if (!expr || expr->type != EXPR_DEREF)
 123                         return expr;
 124                 expr = expr->deref;
 125                 if (expr->type == EXPR_PREOP && expr->op == '*')
 126                         expr = expr->unop;
 127         }
 128         return expr;
 129 }
 130 
 131 static int partial_deref_to_offset_str(struct expression *expr, int cnt, char op, char *buf, int size)
 132 {
 133         int n, offset;
 134 
 135         if (cnt == 0)
 136                 return snprintf(buf, size, "%c0", op);
 137 
 138         n = 0;
 139         while (--cnt >= 0) {
 140                 offset = get_member_offset_from_deref(expr);
 141                 if (offset < 0)
 142                         return -1;
 143                 n += snprintf(buf + n, size - n, "%c%d", op, offset);
 144                 if (expr->type != EXPR_DEREF)
 145                         return -1;
 146                 expr = expr->deref;
 147                 if (expr->type == EXPR_PREOP && expr->op == '*')
 148                         expr = expr->unop;
 149         }
 150 
 151         return n;
 152 }
 153 
 154 static char *get_shared_str(struct expression *expr, struct expression *container)
 155 {
 156         struct expression *one, *two;
 157         int exp, cont, min, ret, n;
 158         static char buf[48];
 159 
 160         exp = get_deref_count(expr);
 161         cont = get_deref_count(container);
 162         if (exp < 0 || cont < 0)
 163                 return NULL;
 164 
 165         min = (exp < cont) ? exp : cont;
 166         while (min >= 0) {
 167                 one = get_partial_deref(expr, exp - min);
 168                 two = get_partial_deref(container, cont - min);
 169                 if (expr_equiv(one, two))
 170                         goto found;
 171                 min--;
 172         }
 173 
 174         return NULL;
 175 
 176 found:
 177         ret = partial_deref_to_offset_str(expr, exp - min, '-', buf, sizeof(buf));
 178         if (ret < 0)
 179                 return NULL;
 180         n = ret;
 181         ret = partial_deref_to_offset_str(container, cont - min, '+', buf + ret, sizeof(buf) - ret);
 182         if (ret < 0)
 183                 return NULL;
 184         n += ret;
 185         if (n >= sizeof(buf))
 186                 return NULL;
 187 
 188         return buf;
 189 }
 190 
 191 static char *get_stored_container_name(struct expression *container,
 192                                        struct expression *expr)
 193 {
 194         struct smatch_state *state;
 195         static char buf[64];
 196         char *p;
 197         int param;
 198 
 199         if (!container || container->type != EXPR_SYMBOL)
 200                 return NULL;
 201         if (!expr || expr->type != EXPR_SYMBOL)
 202                 return NULL;
 203         state = get_state_expr(param_id, expr);
 204         if (!state)
 205                 return NULL;
 206 
 207         snprintf(buf, sizeof(buf), "%s", state->name);
 208         p = strchr(buf, '|');
 209         if (!p)
 210                 return NULL;
 211         *p = '\0';
 212         param = atoi(p + 2);
 213         if (get_param_sym_from_num(param) == container->symbol)
 214                 return buf;
 215         return NULL;
 216 }
 217 
 218 static char *get_container_name_helper(struct expression *container, struct expression *expr)
 219 {
 220         struct symbol *container_sym, *sym;
 221         static char buf[64];
 222         char *ret, *shared;
 223         bool star;
 224 
 225         expr = strip_expr(expr);
 226         container = strip_expr(container);
 227         if (!expr || !container)
 228                 return NULL;
 229 
 230         ret = get_stored_container_name(container, expr);
 231         if (ret)
 232                 return ret;
 233 
 234         sym = expr_to_sym(expr);
 235         container_sym = expr_to_sym(container);
 236         if (!sym || !container_sym)
 237                 return NULL;
 238         if (sym != container_sym)
 239                 return NULL;
 240 
 241         if (container->type == EXPR_DEREF)
 242                 star = true;
 243         else
 244                 star = false;
 245 
 246         if (container->type == EXPR_PREOP && container->op == '&')
 247                 container = strip_expr(container->unop);
 248         if (expr->type == EXPR_PREOP && expr->op == '&')
 249                 expr = strip_expr(expr->unop);
 250 
 251         shared = get_shared_str(expr, container);
 252         if (!shared)
 253                 return NULL;
 254         if (star)
 255                 snprintf(buf, sizeof(buf), "*(%s)", shared);
 256         else
 257                 snprintf(buf, sizeof(buf), "%s", shared);
 258 
 259         return buf;
 260 }
 261 
 262 char *get_container_name(struct expression *container, struct expression *expr)
 263 {
 264         char *ret;
 265 
 266         ret = get_container_name_helper(container, expr);
 267         if (ret)
 268                 return ret;
 269 
 270         ret = get_container_name_helper(get_assigned_expr(container), expr);
 271         if (ret)
 272                 return ret;
 273 
 274         ret = get_container_name_helper(container, get_assigned_expr(expr));
 275         if (ret)
 276                 return ret;
 277 
 278         ret = get_container_name_helper(get_assigned_expr(container),
 279                                         get_assigned_expr(expr));
 280         if (ret)
 281                 return ret;
 282 
 283         return NULL;
 284 }
 285 
 286 static bool is_fn_ptr(struct expression *expr)
 287 {
 288         struct symbol *type;
 289 
 290         if (!expr)
 291                 return false;
 292         if (expr->type != EXPR_SYMBOL && expr->type != EXPR_DEREF)
 293                 return false;
 294 
 295         type = get_type(expr);
 296         if (!type || type->type != SYM_PTR)
 297                 return false;
 298         type = get_real_base_type(type);
 299         if (!type || type->type != SYM_FN)
 300                 return false;
 301         return true;
 302 }
 303 
 304 static void match_call(struct expression *call)
 305 {
 306         struct expression *fn, *arg, *tmp;
 307         bool found = false;
 308         int fn_param, param;
 309         char buf[32];
 310         char *name;
 311 
 312         /*
 313          * We're trying to link the function with the parameter.  There are a
 314          * couple ways this can be passed:
 315          * foo->func(foo, ...);
 316          * foo->func(foo->x, ...);
 317          * foo->bar.func(&foo->bar, ...);
 318          * foo->bar->baz->func(foo, ...);
 319          *
 320          * So the method is basically to subtract the offsets until we get to
 321          * the common bit, then add the member offsets to get the parameter.
 322          *
 323          * If we're taking an address then the offset math is not stared,
 324          * otherwise it is.  Starred means dereferenced.
 325          */
 326         fn = strip_expr(call->fn);
 327 
 328         param = -1;
 329         FOR_EACH_PTR(call->args, arg) {
 330                 param++;
 331 
 332                 name = get_container_name(arg, fn);
 333                 if (!name)
 334                         continue;
 335 
 336                 found = true;
 337                 sql_insert_caller_info(call, CONTAINER, param, name, "$(-1)");
 338         } END_FOR_EACH_PTR(arg);
 339 
 340         if (found)
 341                 return;
 342 
 343         fn_param = -1;
 344         FOR_EACH_PTR(call->args, arg) {
 345                 fn_param++;
 346                 if (!is_fn_ptr(arg))
 347                         continue;
 348                 param = -1;
 349                 FOR_EACH_PTR(call->args, tmp) {
 350                         param++;
 351 
 352                         /* the function isn't it's own container */
 353                         if (arg == tmp)
 354                                 continue;
 355 
 356                         name = get_container_name(tmp, arg);
 357                         if (!name)
 358                                 continue;
 359 
 360                         snprintf(buf, sizeof(buf), "$%d", param);
 361                         sql_insert_caller_info(call, CONTAINER, fn_param, name, buf);
 362                         return;
 363                 } END_FOR_EACH_PTR(tmp);
 364         } END_FOR_EACH_PTR(arg);
 365 }
 366 
 367 static void db_passed_container(const char *name, struct symbol *sym, char *key, char *value)
 368 {
 369         char buf[64];
 370 
 371         snprintf(buf, sizeof(buf), "%s|%s", key, value);
 372         set_state(param_id, name, sym, alloc_state_str(buf));
 373 }
 374 
/*
 * Accumulator shared between load_tag_info_sym() and its SQL row callback
 * save_vals().
 */
struct db_info {
        struct symbol *arg;     /* parameter symbol the values are loaded for */
        int prev_offset;        /* offset of the previous row; -1 before the first row */
        struct range_list *rl;  /* union of the values seen so far for prev_offset */
        int star;               /* non-zero: parse values with arg's own base type */
        struct stree *stree;    /* states produced by set_param_value() */
};
 382 
 383 static struct symbol *get_member_from_offset(struct symbol *sym, int offset)
 384 {
 385         struct symbol *type, *tmp;
 386         int cur;
 387 
 388         type = get_real_base_type(sym);
 389         if (!type || type->type != SYM_PTR)
 390                 return NULL;
 391         type = get_real_base_type(type);
 392         if (!type || type->type != SYM_STRUCT)
 393                 return NULL;
 394 
 395         cur = 0;
 396         FOR_EACH_PTR(type->symbol_list, tmp) {
 397                 cur = ALIGN(cur, tmp->ctype.alignment);
 398                 if (offset == cur)
 399                         return tmp;
 400                 cur += type_bytes(tmp);
 401         } END_FOR_EACH_PTR(tmp);
 402         return NULL;
 403 }
 404 
 405 static struct symbol *get_member_type_from_offset(struct symbol *sym, int offset)
 406 {
 407         struct symbol *base_type;
 408         struct symbol *member;
 409 
 410         base_type = get_real_base_type(sym);
 411         if (base_type && base_type->type == SYM_PTR)
 412                 base_type = get_real_base_type(base_type);
 413         if (offset == 0 && base_type && base_type->type == SYM_BASETYPE)
 414                 return base_type;
 415 
 416         member = get_member_from_offset(sym, offset);
 417         if (!member)
 418                 return NULL;
 419         return get_real_base_type(member);
 420 }
 421 
 422 static const char *get_name_from_offset(struct symbol *arg, int offset)
 423 {
 424         struct symbol *member, *type;
 425         const char *name;
 426         static char fullname[256];
 427 
 428         name = arg->ident->name;
 429 
 430         type = get_real_base_type(arg);
 431         if (!type || type->type != SYM_PTR)
 432                 return name;
 433 
 434         type = get_real_base_type(type);
 435         if (!type)
 436                 return NULL;
 437         if (type->type != SYM_STRUCT) {
 438                 snprintf(fullname, sizeof(fullname), "*%s", name);
 439                 return fullname;
 440         }
 441 
 442         member = get_member_from_offset(arg, offset);
 443         if (!member || !member->ident)
 444                 return NULL;
 445 
 446         snprintf(fullname, sizeof(fullname), "%s->%s", name, member->ident->name);
 447         return fullname;
 448 }
 449 
 450 static void set_param_value(struct stree **stree, struct symbol *arg, int offset, struct range_list *rl)
 451 {
 452         const char *name;
 453 
 454         name = get_name_from_offset(arg, offset);
 455         if (!name)
 456                 return;
 457         set_state_stree(stree, SMATCH_EXTRA, name, arg, alloc_estate_rl(rl));
 458 }
 459 
 460 static int save_vals(void *_db_info, int argc, char **argv, char **azColName)
 461 {
 462         struct db_info *db_info = _db_info;
 463         struct symbol *type;
 464         struct range_list *rl;
 465         int offset = 0;
 466         const char *value;
 467 
 468         if (argc == 2) {
 469                 offset = atoi(argv[0]);
 470                 value = argv[1];
 471         } else {
 472                 value = argv[0];
 473         }
 474 
 475         if (db_info->prev_offset != -1 &&
 476             db_info->prev_offset != offset) {
 477                 set_param_value(&db_info->stree, db_info->arg, db_info->prev_offset, db_info->rl);
 478                 db_info->rl = NULL;
 479         }
 480 
 481         db_info->prev_offset = offset;
 482 
 483         type = get_real_base_type(db_info->arg);
 484         if (db_info->star)
 485                 goto found_type;
 486         if (type->type != SYM_PTR)
 487                 return 0;
 488         type = get_real_base_type(type);
 489         if (type->type == SYM_BASETYPE)
 490                 goto found_type;
 491         type = get_member_type_from_offset(db_info->arg, offset);
 492 found_type:
 493         str_to_rl(type, (char *)value, &rl);
 494         if (db_info->rl)
 495                 db_info->rl = rl_union(db_info->rl, rl);
 496         else
 497                 db_info->rl = rl;
 498 
 499         return 0;
 500 }
 501 
 502 static struct stree *load_tag_info_sym(mtag_t tag, struct symbol *arg, int arg_offset, int star)
 503 {
 504         struct db_info db_info = {
 505                 .arg = arg,
 506                 .prev_offset = -1,
 507                 .star = star,
 508         };
 509         struct symbol *type;
 510 
 511         if (!tag || !arg->ident)
 512                 return NULL;
 513 
 514         type = get_real_base_type(arg);
 515         if (!type)
 516                 return NULL;
 517         if (!star) {
 518                 if (type->type != SYM_PTR)
 519                         return NULL;
 520                 type = get_real_base_type(type);
 521                 if (!type)
 522                         return NULL;
 523         }
 524 
 525         if (star || type->type == SYM_BASETYPE) {
 526                 run_sql(save_vals, &db_info,
 527                         "select value from mtag_data where tag = %lld and offset = %d and type = %d;",
 528                         tag, arg_offset, DATA_VALUE);
 529         } else {  /* presumably the parameter is a struct pointer */
 530                 run_sql(save_vals, &db_info,
 531                         "select offset, value from mtag_data where tag = %lld and type = %d order by offset;",
 532                         tag, DATA_VALUE);
 533         }
 534 
 535         if (db_info.prev_offset != -1)
 536                 set_param_value(&db_info.stree, arg, db_info.prev_offset, db_info.rl);
 537 
 538         // FIXME: handle an offset correctly
 539         if (!star && !arg_offset) {
 540                 sval_t sval;
 541 
 542                 sval.type = get_real_base_type(arg);
 543                 sval.uvalue = tag;
 544                 set_state_stree(&db_info.stree, SMATCH_EXTRA, arg->ident->name, arg, alloc_estate_sval(sval));
 545         }
 546         return db_info.stree;
 547 }
 548 
 549 static void load_container_data(struct symbol *arg, const char *info)
 550 {
 551         mtag_t cur_tag, container_tag, arg_tag;
 552         int container_offset, arg_offset;
 553         struct sm_state *sm;
 554         struct stree *stree;
 555         char *p, *cont;
 556         char copy[64];
 557         bool star = 0;
 558 
 559         snprintf(copy, sizeof(copy), "%s", info);
 560         p = strchr(copy, '|');
 561         if (!p)
 562                 return;
 563         *p = '\0';
 564         cont = p + 1;
 565         p = copy;
 566         if (p[0] == '*') {
 567                 star = 1;
 568                 p += 2;
 569         }
 570 
 571         if (strcmp(cont, "$(-1)") != 0)
 572                 return;
 573 
 574         if (!get_toplevel_mtag(cur_func_sym, &cur_tag))
 575                 return;
 576 
 577         while (true) {
 578                 container_offset = strtoul(p, &p, 0);
 579                 if (local_debug)
 580                         sm_msg("%s: cur_tag = %llu container_offset = %d",
 581                                __func__, cur_tag, container_offset);
 582                 if (!mtag_map_select_container(cur_tag, -container_offset, &container_tag))
 583                         return;
 584                 cur_tag = container_tag;
 585                 if (local_debug)
 586                         sm_msg("%s: container_tag = %llu p = '%s'",
 587                                __func__, container_tag, p);
 588                 if (!p)
 589                         return;
 590                 if (p[0] != '-')
 591                         break;
 592                 p++;
 593         }
 594 
 595         if (p[0] != '+')
 596                 return;
 597 
 598         p++;
 599         arg_offset = strtoul(p, &p, 0);
 600         if (p && *p && *p != ')')
 601                 return;
 602 
 603         if (!arg_offset || star) {
 604                 arg_tag = container_tag;
 605         } else {
 606                 if (!mtag_map_select_tag(container_tag, arg_offset, &arg_tag))
 607                         return;
 608         }
 609 
 610         stree = load_tag_info_sym(arg_tag, arg, arg_offset, star);
 611         FOR_EACH_SM(stree, sm) {
 612                 set_state(sm->owner, sm->name, sm->sym, sm->state);
 613         } END_FOR_EACH_SM(sm);
 614         free_stree(&stree);
 615 }
 616 
 617 static void handle_passed_container(struct symbol *sym)
 618 {
 619         struct symbol *arg;
 620         struct smatch_state *state;
 621 
 622         FOR_EACH_PTR(cur_func_sym->ctype.base_type->arguments, arg) {
 623                 state = get_state(param_id, arg->ident->name, arg);
 624                 if (!state || state == &merged)
 625                         continue;
 626                 load_container_data(arg, state->name);
 627         } END_FOR_EACH_PTR(arg);
 628 }
 629 
/*
 * Registration entry point for the first half of this check.
 *
 * Stores @id in my_id, installs print_returns_container_of() as a split
 * return callback and match_call() as a FUNCTION_CALL_HOOK handler.
 */
void register_container_of(int id)
{
        my_id = id;

        add_split_return_callback(&print_returns_container_of);
        add_hook(&match_call, FUNCTION_CALL_HOOK);
}
 637 
/*
 * Registration entry point for the second half of this check.
 *
 * Stores @id in param_id, which owns the "key|value" states created by
 * db_passed_container().  It then wires up the CONTAINER caller-info
 * callback, the merge hook for those states and the AFTER_DEF_HOOK handler
 * that consumes them.
 */
void register_container_of2(int id)
{
        param_id = id;

        set_dynamic_states(param_id);
        select_caller_info_hook(db_passed_container, CONTAINER);
        add_merge_hook(param_id, &merge_str_state);
        add_hook(&handle_passed_container, AFTER_DEF_HOOK);
}
 647