1 #!/usr/bin/python
   2 
   3 # Copyright (C) 2013 Oracle.
   4 #
   5 # Licensed under the Open Software License version 1.1
   6 
   7 import sqlite3
   8 import sys
   9 import re
  10 
# Open the smatch database.  The path is relative, so sqlite3 looks for
# 'smatch_db.sqlite' in the current working directory.  `con` is the
# module-wide connection that the query functions below create cursors from.
try:
    con = sqlite3.connect('smatch_db.sqlite')
except sqlite3.Error, e:
    # Report the first argument of the sqlite3 error and exit with status 1.
    print "Error %s:" % e.args[0]
    sys.exit(1)
  16 
  17 def usage():
  18     print "%s" %(sys.argv[0])
  19     print "<function> - how a function is called"
  20     print "return_states <function> - what a function returns"
  21     print "call_tree <function> - show the call tree"
  22     print "where <struct_type> <member> - where a struct member is set"
  23     print "type_size <struct_type> <member> - how a struct member is allocated"
  24     print "data_info <struct_type> <member> - information about a given data type"
  25     print "function_ptr <function> - which function pointers point to this"
  26     print "trace_param <function> <param> - trace where a parameter came from"
  27     print "locals <file> - print the local values in a file."
  28     sys.exit(1)
  29 
  30 function_ptrs = []
  31 searched_ptrs = []
  32 def get_function_pointers_helper(func):
  33     cur = con.cursor()
  34     cur.execute("select distinct ptr from function_ptr where function = '%s';" %(func))
  35     for row in cur:
  36         ptr = row[0]
  37         if ptr in function_ptrs:
  38             continue
  39         function_ptrs.append(ptr)
  40         if not ptr in searched_ptrs:
  41             searched_ptrs.append(ptr)
  42             get_function_pointers_helper(ptr)
  43 
  44 def get_function_pointers(func):
  45     global function_ptrs
  46     global searched_ptrs
  47     function_ptrs = [func]
  48     searched_ptrs = [func]
  49     get_function_pointers_helper(func)
  50     return function_ptrs
  51 
  52 db_types = {   0: "INTERNAL",
  53              101: "PARAM_CLEARED",
  54              103: "PARAM_LIMIT",
  55              104: "PARAM_FILTER",
  56             1001: "PARAM_VALUE",
  57             1002: "BUF_SIZE",
  58             1003: "USER_DATA",
  59             1004: "CAPPED_DATA",
  60             1005: "RETURN_VALUE",
  61             1006: "DEREFERENCE",
  62             1007: "RANGE_CAP",
  63             1008: "LOCK_HELD",
  64             1009: "LOCK_RELEASED",
  65             1010: "ABSOLUTE_LIMITS",
  66             1012: "PARAM_ADD",
  67             1013: "PARAM_FREED",
  68             1014: "DATA_SOURCE",
  69             1015: "FUZZY_MAX",
  70             1016: "STR_LEN",
  71             1017: "ARRAY_LEN",
  72             1018: "CAPABLE",
  73             1019: "NS_CAPABLE",
  74             1022: "TYPE_LINK",
  75             1023: "UNTRACKED_PARAM",
  76             1024: "CULL_PATH",
  77             1025: "PARAM_SET",
  78             1026: "PARAM_USED",
  79             1027: "BYTE_UNITS",
  80             1028: "COMPARE_LIMIT",
  81             1029: "PARAM_COMPARE",
  82             8017: "USER_DATA2",
  83             8018: "NO_OVERFLOW",
  84             8019: "NO_OVERFLOW_SIMPLE",
  85             8020: "LOCKED",
  86             8021: "UNLOCKED",
  87             8023: "ATOMIC_INC",
  88             8024: "ATOMIC_DEC",
  89 };
  90 
def add_range(rl, min_val, max_val):
    """Return a range list made of `rl` with [min_val, max_val] merged in.

    rl      -- list of [min, max] pairs (inclusive bounds).
               NOTE(review): the comparisons below appear to assume the pairs
               are sorted ascending and do not overlap -- confirm.
    min_val -- lower bound of the range being added.
    max_val -- upper bound of the range being added.

    Ranges that overlap the new one, or that touch it (differ by exactly 1),
    are combined into a single pair.  `rl` itself is not appended to, but the
    returned list can contain the same inner pair objects as `rl` (the
    untouched tail is copied with a slice).
    """
    check_next = 0      # set once the new range has been merged into `ret`
                        # but may still swallow or join later ranges
    done = 0            # set when the loop exits through a break
    ret = []
    idx = 0

    if len(rl) == 0:
        return [[min_val, max_val]]

    for idx in range(len(rl)):
        cur_min = rl[idx][0]
        cur_max = rl[idx][1]

        # we already merged the new range but we might need to change later
        # ranges if they over lap with more than one
        if check_next:
            # join with added range
            if max_val + 1 == cur_min:
                ret[len(ret) - 1][1] = cur_max
                done = 1
                break
            # don't overlap
            if max_val < cur_min:
                ret.append([cur_min, cur_max])
                done = 1
                break
            # partially overlap
            if max_val < cur_max:
                ret[len(ret) - 1][1] = cur_max
                done = 1
                break
            # completely overlap: the current range is swallowed, keep looking
            continue

        # join 2 ranges into one
        if max_val + 1 == cur_min:
            ret.append([min_val, cur_max])
            done = 1
            break
        # range is entirely below
        if max_val < cur_min:
            ret.append([min_val, max_val])
            ret.append([cur_min, cur_max])
            done = 1
            break
        # range is partially below
        if min_val < cur_min:
            if max_val <= cur_max:
                ret.append([min_val, cur_max])
                done = 1
                break
            else:
                # the new range extends past this one; later ranges may be
                # affected too
                ret.append([min_val, max_val])
                check_next = 1
                continue
        # range already included
        if max_val <= cur_max:
            ret.append([cur_min, cur_max])
            done = 1
            break;
        # range partially above
        if min_val <= cur_max:
            ret.append([cur_min, max_val])
            check_next = 1
            continue
        # join 2 ranges on the other side
        if min_val - 1 == cur_max:
            ret.append([cur_min, max_val])
            check_next = 1
            continue
        # range is above
        ret.append([cur_min, cur_max])

    # `idx` still holds the index of the last range examined by the loop.
    if idx + 1 < len(rl):          # we hit a break statement
        ret = ret + rl[idx + 1:]
    elif done:                     # we hit a break on the last iteration
        pass
    elif not check_next:           # it's past the end of the rl
        ret.append([min_val, max_val])

    return ret;
 172 
 173 def rl_union(rl1, rl2):
 174     ret = []
 175     for r in rl1:
 176         ret = add_range(ret, r[0], r[1])
 177     for r in rl2:
 178         ret = add_range(ret, r[0], r[1])
 179 
 180     if (rl1 or rl2) and not ret:
 181         print "bug: merging %s + %s gives empty" %(rl1, rl2)
 182 
 183     return ret
 184 
 185 def txt_to_val(txt):
 186     if txt == "s64min":
 187         return -(2**63)
 188     elif txt == "s32min":
 189         return -(2**31)
 190     elif txt == "s16min":
 191         return -(2**15)
 192     elif txt == "s64max":
 193         return 2**63 - 1
 194     elif txt == "s32max":
 195         return 2**31 - 1
 196     elif txt == "s16max":
 197         return 2**15 - 1
 198     elif txt == "u64max":
 199         return 2**64 - 1
 200     elif txt == "u32max":
 201         return 2**32 - 1
 202     elif txt == "u16max":
 203         return 2**16 - 1
 204     else:
 205         try:
 206             return int(txt)
 207         except ValueError:
 208             return 0
 209 
 210 def val_to_txt(val):
 211     if val == -(2**63):
 212         return "s64min"
 213     elif val == -(2**31):
 214         return "s32min"
 215     elif val == -(2**15):
 216         return "s16min"
 217     elif val == 2**63 - 1:
 218         return "s64max"
 219     elif val == 2**31 - 1:
 220         return "s32max"
 221     elif val == 2**15 - 1:
 222         return "s16max"
 223     elif val == 2**64 - 1:
 224         return "u64max"
 225     elif val == 2**32 - 1:
 226         return "u32max"
 227     elif val == 2**16 - 1:
 228         return "u16max"
 229     elif val < 0:
 230         return "(%d)" %(val)
 231     else:
 232         return "%d" %(val)
 233 
 234 def get_next_str(txt):
 235     val = ""
 236     parsed = 0
 237 
 238     if txt[0] == '(':
 239         parsed += 1
 240         for char in txt[1:]:
 241             if char == ')':
 242                 break
 243             parsed += 1
 244         val = txt[1:parsed]
 245         parsed += 1
 246     elif txt[0] == 's' or txt[0] == 'u':
 247         parsed += 6
 248         val = txt[:parsed]
 249     else:
 250         if txt[0] == '-':
 251             parsed += 1
 252         for char in txt[parsed:]:
 253             if char == '-':
 254                 break
 255             parsed += 1
 256         val = txt[:parsed]
 257     return [parsed, val]
 258 
 259 def txt_to_rl(txt):
 260     if len(txt) == 0:
 261         return []
 262 
 263     ret = []
 264     pairs = txt.split(",")
 265     for pair in pairs:
 266         cnt, min_str = get_next_str(pair)
 267         if cnt == len(pair):
 268             max_str = min_str
 269         else:
 270             cnt, max_str = get_next_str(pair[cnt + 1:])
 271         min_val = txt_to_val(min_str)
 272         max_val = txt_to_val(max_str)
 273         ret.append([min_val, max_val])
 274 
 275 #    Hm...  Smatch won't call INT_MAX s32max if the variable is unsigned.
 276 #    if txt != rl_to_txt(ret):
 277 #        print "bug: converting: text = %s rl = %s internal = %s" %(txt, rl_to_txt(ret), ret)
 278 
 279     return ret
 280 
 281 def rl_to_txt(rl):
 282     ret = ""
 283     for idx in range(len(rl)):
 284         cur_min = rl[idx][0]
 285         cur_max = rl[idx][1]
 286 
 287         if idx != 0:
 288             ret += ","
 289 
 290         if cur_min == cur_max:
 291             ret += val_to_txt(cur_min)
 292         else:
 293             ret += val_to_txt(cur_min)
 294             ret += "-"
 295             ret += val_to_txt(cur_max)
 296     return ret
 297 
 298 def type_to_str(type_int):
 299 
 300     t = int(type_int)
 301     if db_types.has_key(t):
 302         return db_types[t]
 303     return type_int
 304 
 305 def type_to_int(type_string):
 306     for k in db_types.keys():
 307         if db_types[k] == type_string:
 308             return k
 309     return -1
 310 
 311 def display_caller_info(printed, cur, param_names):
 312     for txt in cur:
 313         if not printed:
 314             print "file | caller | function | type | parameter | key | value |"
 315         printed = 1
 316 
 317         parameter = int(txt[6])
 318         key = txt[7]
 319         if len(param_names) and parameter in param_names:
 320             key = key.replace("$", param_names[parameter])
 321 
 322         print "%20s | %20s | %20s |" %(txt[0], txt[1], txt[2]),
 323         print " %10s |" %(type_to_str(txt[5])),
 324         print " %d | %s | %s" %(parameter, key, txt[8])
 325     return printed
 326 
 327 def get_caller_info(filename, ptrs, my_type):
 328     cur = con.cursor()
 329     param_names = get_param_names(filename, func)
 330     printed = 0
 331     type_filter = ""
 332     if my_type != "":
 333         type_filter = "and type = %d" %(type_to_int(my_type))
 334     for ptr in ptrs:
 335         cur.execute("select * from caller_info where function = '%s' %s;" %(ptr, type_filter))
 336         printed = display_caller_info(printed, cur, param_names)
 337 
 338 def print_caller_info(filename, func, my_type = ""):
 339     ptrs = get_function_pointers(func)
 340     get_caller_info(filename, ptrs, my_type)
 341 
 342 def merge_values(param_names, vals, cur):
 343     for txt in cur:
 344         parameter = int(txt[0])
 345         name = txt[1]
 346         rl = txt_to_rl(txt[2])
 347         if parameter in param_names:
 348             name = name.replace("$", param_names[parameter])
 349 
 350         if not parameter in vals:
 351             vals[parameter] = {}
 352 
 353         # the first item on the list is the number of rows.  it's incremented
 354         # every time we call merge_values().
 355         if name in vals[parameter]:
 356             vals[parameter][name] = [vals[parameter][name][0] + 1, rl_union(vals[parameter][name][1], rl)]
 357         else:
 358             vals[parameter][name] = [1, rl]
 359 
 360 def get_param_names(filename, func):
 361     cur = con.cursor()
 362     param_names = {}
 363     cur.execute("select parameter, value from parameter_name where file = '%s' and function = '%s';" %(filename, func))
 364     for txt in cur:
 365         parameter = int(txt[0])
 366         name = txt[1]
 367         param_names[parameter] = name
 368     if len(param_names):
 369         return param_names
 370 
 371     cur.execute("select parameter, value from parameter_name where function = '%s';" %(func))
 372     for txt in cur:
 373         parameter = int(txt[0])
 374         name = txt[1]
 375         param_names[parameter] = name
 376     return param_names
 377 
 378 def get_caller_count(ptrs):
 379     cur = con.cursor()
 380     count = 0
 381     for ptr in ptrs:
 382         cur.execute("select count(distinct(call_id)) from caller_info where function = '%s';" %(ptr))
 383         for txt in cur:
 384             count += int(txt[0])
 385     return count
 386 
 387 def print_merged_caller_values(filename, func, ptrs, param_names, call_cnt):
 388     cur = con.cursor()
 389     vals = {}
 390     for ptr in ptrs:
 391         cur.execute("select parameter, key, value from caller_info where function = '%s' and type = %d;" %(ptr, type_to_int("PARAM_VALUE")))
 392         merge_values(param_names, vals, cur);
 393 
 394     for param in sorted(vals):
 395         for name in sorted(vals[param]):
 396             if vals[param][name][0] != call_cnt:
 397                 continue
 398             print "%d %s -> %s" %(param, name, rl_to_txt(vals[param][name][1]))
 399 
 400 
 401 def print_unmerged_caller_values(filename, func, ptrs, param_names):
 402     cur = con.cursor()
 403     for ptr in ptrs:
 404         prev = -1
 405         cur.execute("select file, caller, call_id, parameter, key, value from caller_info where function = '%s' and type = %d;" %(ptr, type_to_int("PARAM_VALUE")))
 406         for filename, caller, call_id, parameter, name, value in cur:
 407             if prev != int(call_id):
 408                 prev = int(call_id)
 409 
 410             parameter = int(parameter)
 411             if parameter < len(param_names):
 412                 name = name.replace("$", param_names[parameter])
 413             else:
 414                 name = name.replace("$", "$%d" %(parameter))
 415 
 416             print "%s | %s | %s | %s" %(filename, caller, name, value)
 417         print "=========================="
 418 
 419 def print_caller_values(filename, func, ptrs):
 420     param_names = get_param_names(filename, func)
 421     call_cnt = get_caller_count(ptrs)
 422 
 423     print_merged_caller_values(filename, func, ptrs, param_names, call_cnt)
 424     print "=========================="
 425     print_unmerged_caller_values(filename, func, ptrs, param_names)
 426 
 427 def caller_info_values(filename, func):
 428     ptrs = get_function_pointers(func)
 429     print_caller_values(filename, func, ptrs)
 430 
 431 def print_return_states(func):
 432     cur = con.cursor()
 433     cur.execute("select * from return_states where function = '%s';" %(func))
 434     count = 0
 435     for txt in cur:
 436         printed = 1
 437         if count == 0:
 438             print "file | function | return_id | return_value | type | param | key | value |"
 439         count += 1
 440         print "%s | %s | %2s | %13s" %(txt[0], txt[1], txt[3], txt[4]),
 441         print "| %13s |" %(type_to_str(txt[6])),
 442         print " %2d | %20s | %20s |" %(txt[7], txt[8], txt[9])
 443 
 444 def print_return_implies(func):
 445     cur = con.cursor()
 446     cur.execute("select * from return_implies where function = '%s';" %(func))
 447     count = 0
 448     for txt in cur:
 449         if not count:
 450             print "file | function | type | param | key | value |"
 451         count += 1
 452         print "%15s | %15s" %(txt[0], txt[1]),
 453         print "| %15s" %(type_to_str(txt[4])),
 454         print "| %3d | %s | %15s |" %(txt[5], txt[6], txt[7])
 455 
 456 def print_type_size(struct_type, member):
 457     cur = con.cursor()
 458     cur.execute("select * from type_size where type like '(struct %s)->%s';" %(struct_type, member))
 459     print "type | size"
 460     for txt in cur:
 461         print "%-15s | %s" %(txt[0], txt[1])
 462 
 463     cur.execute("select * from function_type_size where type like '(struct %s)->%s';" %(struct_type, member))
 464     print "file | function | type | size"
 465     for txt in cur:
 466         print "%-15s | %-15s | %-15s | %s" %(txt[0], txt[1], txt[2], txt[3])
 467 
 468 def print_data_info(struct_type, member):
 469     cur = con.cursor()
 470     cur.execute("select * from data_info where data like '(struct %s)->%s';" %(struct_type, member))
 471     print "file | data | type | value"
 472     for txt in cur:
 473         print "%-15s | %-15s | %-15s | %s" %(txt[0], txt[1], type_to_str(txt[2]), txt[3])
 474 
 475 def print_fn_ptrs(func):
 476     ptrs = get_function_pointers(func)
 477     if not ptrs:
 478         return
 479     print "%s = " %(func),
 480     print(ptrs)
 481 
 482 def print_functions(member):
 483     cur = con.cursor()
 484     cur.execute("select * from function_ptr where ptr like '%%->%s';" %(member))
 485     print "File | Pointer | Function | Static"
 486     for txt in cur:
 487         print "%-15s | %-15s | %-15s | %s" %(txt[0], txt[2], txt[1], txt[3])
 488 
 489 def get_callers(func):
 490     ret = []
 491     cur = con.cursor()
 492     ptrs = get_function_pointers(func)
 493     for ptr in ptrs:
 494         cur.execute("select distinct caller from caller_info where function = '%s';" %(ptr))
 495         for row in cur:
 496             ret.append(row[0])
 497     return ret
 498 
 499 printed_funcs = []
 500 def call_tree_helper(func, indent = 0):
 501     global printed_funcs
 502     if func in printed_funcs:
 503         return
 504     print "%s%s()" %(" " * indent, func)
 505     if func == "too common":
 506         return
 507     if indent > 6:
 508         return
 509     printed_funcs.append(func)
 510     callers = get_callers(func)
 511     if len(callers) >= 20:
 512         print "Over 20 callers for %s()" %(func)
 513         return
 514     for caller in callers:
 515         call_tree_helper(caller, indent + 2)
 516 
 517 def print_call_tree(func):
 518     global printed_funcs
 519     printed_funcs = []
 520     call_tree_helper(func)
 521 
 522 def function_type_value(struct_type, member):
 523     cur = con.cursor()
 524     cur.execute("select * from function_type_value where type like '(struct %s)->%s';" %(struct_type, member))
 525     for txt in cur:
 526         print "%-30s | %-30s | %s | %s" %(txt[0], txt[1], txt[2], txt[3])
 527 
 528 def trace_callers(func, param):
 529     sources = []
 530     prev_type = 0
 531 
 532     cur = con.cursor()
 533     ptrs = get_function_pointers(func)
 534     for ptr in ptrs:
 535         cur.execute("select type, caller, value from caller_info where function = '%s' and (type = 0 or type = 1014 or type = 1028) and (parameter = -1 or parameter = %d);" %(ptr, param))
 536         for row in cur:
 537             data_type = int(row[0])
 538             if data_type == 1014:
 539                 sources.append((row[1], row[2]))
 540             elif data_type == 1028:
 541                 sources.append(("%", row[2])) # hack...
 542             elif data_type == 0 and prev_type == 0:
 543                 sources.append((row[1], ""))
 544             prev_type = data_type
 545     return sources
 546 
 547 def trace_param_helper(func, param, indent = 0):
 548     global printed_funcs
 549     if func in printed_funcs:
 550         return
 551     print "%s%s(param %d)" %(" " * indent, func, param)
 552     if func == "too common":
 553         return
 554     if indent > 20:
 555         return
 556     printed_funcs.append(func)
 557     sources = trace_callers(func, param)
 558     for path in sources:
 559 
 560         if len(path[1]) and path[1][0] == 'p' and path[1][1] == ' ':
 561             p = int(path[1][2:])
 562             trace_param_helper(path[0], p, indent + 2)
 563         elif len(path[0]) and path[0][0] == '%':
 564             print "  %s%s" %(" " * indent, path[1])
 565         else:
 566             print "* %s%s %s" %(" " * (indent - 1), path[0], path[1])
 567 
 568 def trace_param(func, param):
 569     global printed_funcs
 570     printed_funcs = []
 571     print "tracing %s %d" %(func, param)
 572     trace_param_helper(func, param)
 573 
 574 def print_locals(filename):
 575     cur = con.cursor()
 576     cur.execute("select file,data,value from data_info where file = '%s' and type = 8029 and value != 0;" %(filename))
 577     for txt in cur:
 578         print "%s | %s | %s" %(txt[0], txt[1], txt[2])
 579 
 580 def constraint(struct_type, member):
 581     cur = con.cursor()
 582     cur.execute("select * from constraints_required where data like '(struct %s)->%s' or bound like '(struct %s)->%s';" %(struct_type, member, struct_type, member))
 583     for txt in cur:
 584         print "%-30s | %-30s | %s | %s" %(txt[0], txt[1], txt[2], txt[3])
 585 
 586 if len(sys.argv) < 2:
 587     usage()
 588 
 589 if len(sys.argv) == 2:
 590     func = sys.argv[1]
 591     print_caller_info("", func)
 592 elif sys.argv[1] == "call_info":
 593     if len(sys.argv) != 4:
 594         usage()
 595     filename = sys.argv[2]
 596     func = sys.argv[3]
 597     caller_info_values(filename, func)
 598     print_caller_info(filename, func)
 599 elif sys.argv[1] == "user_data":
 600     func = sys.argv[2]
 601     print_caller_info(filename, func, "USER_DATA")
 602 elif sys.argv[1] == "param_value":
 603     func = sys.argv[2]
 604     print_caller_info(filename, func, "PARAM_VALUE")
 605 elif sys.argv[1] == "function_ptr" or sys.argv[1] == "fn_ptr":
 606     func = sys.argv[2]
 607     print_fn_ptrs(func)
 608 elif sys.argv[1] == "return_states":
 609     func = sys.argv[2]
 610     print_return_states(func)
 611     print "================================================"
 612     print_return_implies(func)
 613 elif sys.argv[1] == "return_implies":
 614     func = sys.argv[2]
 615     print_return_implies(func)
 616 elif sys.argv[1] == "type_size" or sys.argv[1] == "buf_size":
 617     struct_type = sys.argv[2]
 618     member = sys.argv[3]
 619     print_type_size(struct_type, member)
 620 elif sys.argv[1] == "data_info":
 621     struct_type = sys.argv[2]
 622     member = sys.argv[3]
 623     print_data_info(struct_type, member)
 624 elif sys.argv[1] == "call_tree":
 625     func = sys.argv[2]
 626     print_call_tree(func)
 627 elif sys.argv[1] == "where":
 628     if len(sys.argv) == 3:
 629         struct_type = "%"
 630         member = sys.argv[2]
 631     elif len(sys.argv) == 4:
 632         struct_type = sys.argv[2]
 633         member = sys.argv[3]
 634     function_type_value(struct_type, member)
 635 elif sys.argv[1] == "local":
 636     filename = sys.argv[2]
 637     variable = ""
 638     if len(sys.argv) == 4:
 639         variable = sys.argv[3]
 640     local_values(filename, variable)
 641 elif sys.argv[1] == "functions":
 642     member = sys.argv[2]
 643     print_functions(member)
 644 elif sys.argv[1] == "trace_param":
 645     if len(sys.argv) != 4:
 646         usage()
 647     func = sys.argv[2]
 648     param = int(sys.argv[3])
 649     trace_param(func, param)
 650 elif sys.argv[1] == "locals":
 651     if len(sys.argv) != 3:
 652         usage()
 653     filename = sys.argv[2]
 654     print_locals(filename);
 655 elif sys.argv[1] == "constraint":
 656     if len(sys.argv) == 3:
 657         struct_type = "%"
 658         member = sys.argv[2]
 659     elif len(sys.argv) == 4:
 660         struct_type = sys.argv[2]
 661         member = sys.argv[3]
 662     constraint(struct_type, member)
 663 elif sys.argv[1] == "test":
 664     filename = sys.argv[2]
 665     func = sys.argv[3]
 666     caller_info_values(filename, func)
 667 else:
 668     usage()