$30 off During Our Annual Pro Sale. View Details »

Dispelling py.test magic

Dispelling py.test magic

How to replicate py.test magic in fewer than 100 lines of code.
Demo code available here: https://github.com/oinopion/dispel

Tomek Paczkowski

September 19, 2015
Tweet

More Decks by Tomek Paczkowski

Other Decks in Programming

Transcript

  1. Dispelling py.test
    magic
    Tomek Paczkowski
    PyCon UK

    September 2015

    View Slide

  2. py.test is awesome
    Test discovery
    No boilerplate
    Plugins
    Assertions written with the plain assert statement

    View Slide

  3. def double(x):
    return x * 2
    def test_doubling():
    expected = 5
    assert double(2) == expected

    View Slide

  4. def test_doubling():
    expected = 5
    > assert double(2) == expected
    E assert 4 == 5
    E + where 4 = double(2)
    sample_test.py:6: AssertionError

    View Slide

  5. def test_doubling():
    expected = 5
    > assert double(2) == expected
    E assert 4 == 5
    E + where 4 = double(2)
    sample_test.py:6: AssertionError

    View Slide

  6. How do we replicate this?

    View Slide

  7. Benjamin Peterson
    Behind the scenes of py.test's
    new assertion rewriting
    bit.ly/pytest-ast

    View Slide

  8. Step one
    import ast

    View Slide

  9. assert double(2) == expected
    assert
      ==
        func call
          variable (double)
          constant (2)
        variable (expected)

    View Slide

  10. ast.parse('assert double(2) == expected')
    bit.ly/docs-ast

    View Slide

  11. Assert(
    test=Compare(
    left=Call(
    func=Name(id='double'),
    args=[Num(n=2)]
    ),
    ops=[Eq()],
    comparators=[Name(id='expected')]
    )
    )

    View Slide

  12. Assert(
    test=Compare(
    left=Call(
    func=Name(id='double'),
    args=[Num(n=2)]
    ),
    ops=[Eq()],
    comparators=[Name(id='expected')]
    )
    )

    View Slide

  13. Assert(
    test=Compare(
    left=Call(
    func=Name(id='double'),
    args=[Num(n=2)]
    ),
    ops=[Eq()],
    comparators=[Name(id='expected')]
    )
    )

    View Slide

  14. Assert(
    test=Compare(
    left=Call(
    func=Name(id='double'),
    args=[Num(n=2)]
    ),
    ops=[Eq()],
    comparators=[Name(id='expected')]
    )
    )

    View Slide

  15. Assert(
    test=Compare(
    left=Call(
    func=Name(id='double'),
    args=[Num(n=2)]
    ),
    ops=[Eq()],
    comparators=[Name(id='expected')]
    )
    )

    View Slide

  16. Assert(
    test=Compare(
    left=Call(
    func=Name(id='double'),
    args=[Num(n=2)]
    ),
    ops=[Eq()],
    comparators=[Name(id='expected')]
    )
    )

    View Slide

  17. Assert(
    test=Compare(
    left=Call(
    func=Name(id='double'),
    args=[Num(n=2)]
    ),
    ops=[Eq()],
    comparators=[Name(id='expected')]
    )
    )
    bit.ly/better-docs-ast

    View Slide

  18. Goal one
    Modify the assert statement to

    call a function with both sides 

    of the comparison

    View Slide

  19. assert double(2) == expected
    assert_equals(double(2), expected)

    View Slide

  20. Node transformer
    Allows modifying the AST
    Implements visitor pattern

    View Slide

  21. class AssertRewrite(NodeTransformer):

    def visit_Assert(self, node):

    call = Call(

    func=Name(

    id='assert_equals', ctx=Load()

    ),

    args=[

    node.test.left,

    node.test.comparators[0]

    ],

    keywords=[]

    )

    new_node = Expr(value=call)

    copy_location(new_node, node)

    fix_missing_locations(new_node)

    return new_node

    View Slide

  22. class AssertRewrite(NodeTransformer):

    def visit_Assert(self, node):

    call = Call(

    func=Name(

    id='assert_equals', ctx=Load()

    ),

    args=[

    node.test.left,

    node.test.comparators[0]

    ],

    keywords=[]

    )

    new_node = Expr(value=call)

    copy_location(new_node, node)

    fix_missing_locations(new_node)

    return new_node

    View Slide

  23. class AssertRewrite(NodeTransformer):

    def visit_Assert(self, node):

    call = Call(

    func=Name(

    id='assert_equals', ctx=Load()

    ),

    args=[

    node.test.left,

    node.test.comparators[0]

    ],

    keywords=[]

    )

    new_node = Expr(value=call)

    copy_location(new_node, node)

    fix_missing_locations(new_node)

    return new_node

    View Slide

  24. class AssertRewrite(NodeTransformer):

    def visit_Assert(self, node):

    call = Call(

    func=Name(

    id='assert_equals', ctx=Load()

    ),

    args=[

    node.test.left,

    node.test.comparators[0]

    ],

    keywords=[]

    )

    new_node = Expr(value=call)

    copy_location(new_node, node)

    fix_missing_locations(new_node)

    return new_node

    View Slide

  25. class AssertRewrite(NodeTransformer):

    def visit_Assert(self, node):

    call = Call(

    func=Name(

    id='assert_equals', ctx=Load()

    ),

    args=[

    node.test.left,

    node.test.comparators[0]

    ],

    keywords=[]

    )

    new_node = Expr(value=call)

    copy_location(new_node, node)

    fix_missing_locations(new_node)

    return new_node

    View Slide

  26. class AssertRewrite(NodeTransformer):

    def visit_Assert(self, node):

    call = Call(

    func=Name(

    id='assert_equals', ctx=Load()

    ),

    args=[

    node.test.left,

    node.test.comparators[0]

    ],

    keywords=[]

    )

    new_node = Expr(value=call)

    copy_location(new_node, node)

    fix_missing_locations(new_node)

    return new_node

    View Slide

  27. class AssertRewrite(NodeTransformer):

    def visit_Assert(self, node):

    call = Call(

    func=Name(

    id='assert_equals', ctx=Load()

    ),

    args=[

    node.test.left,

    node.test.comparators[0]

    ],

    keywords=[]

    )

    new_node = Expr(value=call)

    copy_location(new_node, node)

    fix_missing_locations(new_node)

    return new_node

    View Slide

  28. class AssertRewrite(NodeTransformer):

    def visit_Assert(self, node):

    call = Call(

    func=Name(

    id='assert_equals', ctx=Load()

    ),

    args=[

    node.test.left,

    node.test.comparators[0]

    ],

    keywords=[]

    )

    new_node = Expr(value=call)

    copy_location(new_node, node)

    fix_missing_locations(new_node)

    return new_node

    View Slide

  29. Where is assert_equals
    coming from?

    View Slide

  30. def transform(module):

    import_node = ImportFrom(

    module='test_utils',

    names=[alias('assert_equals', None)],

    lineno=0, col_offset=0,

    )

    module.body[0:0] = [import_node]

    transformer = AssertRewrite()

    return transformer.visit(module)

    View Slide

  31. def transform(module):

    import_node = ImportFrom(

    module='test_utils',

    names=[alias('assert_equals', None)],

    lineno=0, col_offset=0,

    )

    module.body[0:0] = [import_node]

    transformer = AssertRewrite()

    return transformer.visit(module)

    View Slide

  32. def transform(module):

    import_node = ImportFrom(

    module='test_utils',

    names=[alias('assert_equals', None)],

    lineno=0, col_offset=0,

    )

    module.body[0:0] = [import_node]

    transformer = AssertRewrite()

    return transformer.visit(module)

    View Slide

  33. def transform(module):

    import_node = ImportFrom(

    module='test_utils',

    names=[alias('assert_equals', None)],

    lineno=0, col_offset=0,

    )

    module.body[0:0] = [import_node]

    transformer = AssertRewrite()

    return transformer.visit(module)

    View Slide

  34. def transform(module):

    import_node = ImportFrom(

    module='test_utils',

    names=[alias('assert_equals', None)],

    lineno=0, col_offset=0,

    )

    module.body[0:0] = [import_node]

    transformer = AssertRewrite()

    return transformer.visit(module)

    View Slide

  35. def transform(module):

    import_node = ImportFrom(

    module='test_utils',

    names=[alias('assert_equals', None)],

    lineno=0, col_offset=0,

    )

    module.body[0:0] = [import_node]

    transformer = AssertRewrite()

    return transformer.visit(module)

    View Slide

  36. import_node = ImportFrom(

    module='test_utils',

    names=[alias('assert_equals', None)],

    lineno=0, col_offset=0,

    )

    View Slide

  37. import_node = ImportFrom(

    module='test_utils',

    names=[alias('assert_equals', '#eq')],

    lineno=0, col_offset=0,

    )
    Call(

    func=Name(id='assert_equals', ctx=Load()),

    args=[...],

    keywords=[]

    )

    View Slide

  38. import_node = ImportFrom(

    module='test_utils',

    names=[alias('assert_equals', '#eq')],

    lineno=0, col_offset=0,

    )
    Call(

    func=Name(id='#eq', ctx=Load()),

    args=[...],

    keywords=[]

    )

    View Slide

  39. Step two
    import sys

    View Slide

  40. Import path hooks
    sys.path_hooks
    Factory functions for finders
    sys.path_importer_cache
    bit.ly/docs-import

    View Slide

  41. Goal two
    Write an import hook 

    that uses our transformer 

    to modify imported code

    View Slide

  42. def import_hook(path):
    if os.path.abspath('') == path:
    return Finder()
    else:
    raise ImportError
    sys.path_hooks.insert(0, import_hook)
    sys.path_importer_cache.clear()

    View Slide

  43. def import_hook(path):
    if os.path.abspath('') == path:
    return Finder()
    else:
    raise ImportError
    sys.path_hooks.insert(0, import_hook)
    sys.path_importer_cache.clear()

    View Slide

  44. def import_hook(path):
    if os.path.abspath('') == path:
    return Finder()
    else:
    raise ImportError
    sys.path_hooks.insert(0, import_hook)
    sys.path_importer_cache.clear()

    View Slide

  45. def import_hook(path):
    if os.path.abspath('') == path:
    return Finder()
    else:
    raise ImportError
    sys.path_hooks.insert(0, import_hook)
    sys.path_importer_cache.clear()

    View Slide

  46. def import_hook(path):
    if os.path.abspath('') == path:
    return Finder()
    else:
    raise ImportError
    sys.path_hooks.insert(0, import_hook)
    sys.path_importer_cache.clear()

    View Slide

  47. Finder
    Defines one method: find_spec
    Method returns a ModuleSpec, or None if the module is not found

    View Slide

  48. from importlib.util import spec_from_file_location


    class Finder:

    def find_spec(self, module, target=None):

    file_name = module + '.py'

    if not os.path.exists(file_name):

    return None

    return spec_from_file_location(

    name=module, 

    location=file_name, 

    loader=Loader()

    )

    View Slide

  49. from importlib.util import spec_from_file_location


    class Finder:

    def find_spec(self, module, target=None):

    file_name = module + '.py'

    if not os.path.exists(file_name):

    return None

    return spec_from_file_location(

    name=module, 

    location=file_name, 

    loader=Loader()

    )

    View Slide

  50. from importlib.util import spec_from_file_location


    class Finder:

    def find_spec(self, module, target=None):

    file_name = module + '.py'

    if not os.path.exists(file_name):

    return None

    return spec_from_file_location(

    name=module, 

    location=file_name, 

    loader=Loader()

    )

    View Slide

  51. from importlib.util import spec_from_file_location


    class Finder:

    def find_spec(self, module, target=None):

    file_name = module + '.py'

    if not os.path.exists(file_name):

    return None

    return spec_from_file_location(

    name=module, 

    location=file_name, 

    loader=Loader()

    )

    View Slide

  52. from importlib.util import spec_from_file_location


    class Finder:

    def find_spec(self, module, target=None):

    file_name = module + '.py'

    if not os.path.exists(file_name):

    return None

    return spec_from_file_location(

    name=module, 

    location=file_name, 

    loader=Loader()

    )

    View Slide

  53. from importlib.util import spec_from_file_location


    class Finder:

    def find_spec(self, module, target=None):

    file_name = module + '.py'

    if not os.path.exists(file_name):

    return None

    return spec_from_file_location(

    name=module, 

    location=file_name, 

    loader=Loader()

    )

    View Slide

  54. Loader
    Defines one method: exec_module
    Executes module code
    Populates module namespace

    View Slide

  55. class Loader:
    def exec_module(self, module):
    with open(module.__file__, 'rb') as fp:
    source = fp.read()
    tree = ast.parse(source, module.__file__)
    tree = transform(tree)
    code = compile(tree, module.__file__, 'exec')
    exec(code, module.__dict__)

    View Slide

  56. class Loader:
    def exec_module(self, module):
    with open(module.__file__, 'rb') as fp:
    source = fp.read()
    tree = ast.parse(source, module.__file__)
    tree = transform(tree)
    code = compile(tree, module.__file__, 'exec')
    exec(code, module.__dict__)

    View Slide

  57. class Loader:
    def exec_module(self, module):
    with open(module.__file__, 'rb') as f:
    source = f.read()
    tree = ast.parse(source, module.__file__)
    tree = transform(tree)
    code = compile(tree, module.__file__, 'exec')
    exec(code, module.__dict__)

    View Slide

  58. class Loader:
    def exec_module(self, module):
    with open(module.__file__, 'rb') as fp:
    source = fp.read()
    tree = ast.parse(source, module.__file__)
    tree = transform(tree)
    code = compile(tree, module.__file__, 'exec')
    exec(code, module.__dict__)

    View Slide

  59. class Loader:
    def exec_module(self, module):
    with open(module.__file__, 'rb') as fp:
    source = fp.read()
    tree = ast.parse(source, module.__file__)
    tree = transform(tree)
    code = compile(tree, module.__file__, 'exec')
    exec(code, module.__dict__)

    View Slide

  60. class Loader:
    def exec_module(self, module):
    with open(module.__file__, 'rb') as fp:
    source = fp.read()
    tree = ast.parse(source, module.__file__)
    tree = transform(tree)
    code = compile(tree, module.__file__, 'exec')
    exec(code, module.__dict__)

    View Slide

  61. class Loader:
    def exec_module(self, module):
    with open(module.__file__, 'rb') as fp:
    source = fp.read()
    tree = ast.parse(source, module.__file__)
    tree = transform(tree)
    code = compile(tree, module.__file__, 'exec')
    exec(code, module.__dict__)

    View Slide

  62. class Loader:
    def exec_module(self, module):
    with open(module.__file__, 'rb') as fp:
    source = fp.read()
    tree = ast.parse(source, module.__file__)
    tree = transform(tree)
    code = compile(tree, module.__file__, 'exec')
    exec(code, module.__dict__)

    View Slide

  63. class Loader:

    def exec_module(self, module):

    with open(module.__file__, 'rb') as fp:

    source = fp.read()


    tree = ast.parse(source, module.__file__)

    tree = transform(tree)

    code = compile(tree, module.__file__, 'exec')

    module.__dict__['#eq'] = assert_equals

    exec(code, module.__dict__)
    Bonus

    View Slide

  64. Step three
    Test discovery

    View Slide

  65. import sample_test
    sample_test.test_with_assert()

    View Slide

  66. Demo

    View Slide

  67. Summary
    This is probably a giant footgun
    Corner cases are left as an exercise for the reader
    Python is awesome

    View Slide

  68. This presentation

    bit.ly/dispel-pytest
    Demo code

    github.com/oinopion/dispel

    View Slide

  69. Thanks!
    Tomek Paczkowski
    @oinopion

    View Slide