# Copyright (c) 2018, Tom SF Haines

# Feel free to choose the MIT license:
# https://opensource.org/licenses/MIT
# or the BSD 2 clause license:
# https://opensource.org/licenses/BSD-2-Clause
# or ask if you would like another - I don't really care!



def instrument_loops(func):
  """Decorator that rewrites a function so it counts its loop iterations.

  The source of func is fetched, parsed to an abstract syntax tree, and a
  statement equivalent to ``loops[line] += 1`` is inserted as the first
  statement of every ``for`` and ``while`` body. The rewritten function gains
  an extra keyword-only parameter, 'loops' (default None), to which a
  collections.defaultdict(int) *must* be passed whenever a loop actually runs.

  Keys are line numbers relative to the function's ``def`` line, i.e. the
  line directly after ``def`` is 1, the next is 2, etc.

  This decorator must be the innermost one (the last in the decorator list);
  any decorators above it are re-applied to the rewritten function.

  Has a very long list of failure modes: the source must be retrievable via
  inspect, closures over enclosing variables do not survive the recompile,
  a function that already has a parameter called 'loops' will not compile,
  and so on. Yes, this is madness; it is recommended you bite down on a
  stick before reading on.

  Raises TypeError if the retrieved source does not start with a plain
  (non-async) function definition.
  """
  # Imported locally so the decorator is self-contained; the file has no
  # module-level import block providing these names.
  import ast
  import inspect
  import sys
  import textwrap

  # Fetch code and compile half-way, to an abstract syntax tree. Methods and
  # nested functions come back indented, which ast.parse rejects, hence the
  # dedent...
  source = textwrap.dedent(inspect.getsource(func))
  tree = ast.parse(source)

  # Determine name of function (explicit raise: an assert vanishes under -O)...
  func_def = tree.body[0]
  if not isinstance(func_def, ast.FunctionDef):
    raise TypeError('instrument_loops can only be applied to a plain function definition')
  func_name = func_def.name

  # Line of the 'def' keyword within the fetched source; loop keys are
  # relative to it. From Python 3.8 FunctionDef.lineno is the def line itself;
  # earlier versions report the first decorator, so skip over the decorators
  # (assumes each one occupies a single line)...
  def_line = func_def.lineno
  if sys.version_info < (3, 8):
    def_line += len(func_def.decorator_list)

  # Need to remove self from decorator list or we are going to infinite loop
  # (well, actually crash due to the lack of source code second time around,
  # which breaks inspect)...
  func_def.decorator_list = func_def.decorator_list[:-1]

  # Drop the 'loops' keyword argument into the function definition...
  func_def.args.kwonlyargs.append(ast.arg(arg='loops', annotation=None))
  func_def.args.kw_defaults.append(ast.Constant(value=None))

  # Go for a walk and mess with every for and while loop...
  for node in ast.walk(tree):
    if isinstance(node, (ast.For, ast.While)):
      key = ast.Constant(value=node.lineno - def_line)
      if sys.version_info < (3, 9):
        # Older grammars require subscript indices to be wrapped in ast.Index.
        key = ast.Index(value=key)

      inc = ast.AugAssign(target=ast.Subscript(value=ast.Name(id='loops', ctx=ast.Load()),
                                               slice=key, ctx=ast.Store()),
                          op=ast.Add(),
                          value=ast.Constant(value=1))
      node.body.insert(0, inc)

  # Fix line numbers...
  ast.fix_missing_locations(tree)

  # Finish compilation and return the new function. It is executed against the
  # decorated function's own globals, so its imports and helpers resolve, but
  # the definition is captured in a scratch dict so the original name is not
  # overwritten - not this code's job...
  code = compile(tree, '<instrumented_loops>', 'exec')
  namespace = {}
  exec(code, func.__globals__, namespace)

  return namespace[func_name]

