Comparison Between TreeValue and JAX pytree

In this section, we compare the features and performance of TreeValue with JAX's pytree utilities (`jax.tree_util`), which are developed by Google.

_TREE_DATA_1 = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}

Mapping Operation

TreeValue’s Mapping

from treevalue import mapping, FastTreeValue

t = FastTreeValue(_TREE_DATA_1)
mapping(t, lambda x: x ** 2)
%timeit mapping(t, lambda x: x ** 2)
3.59 µs ± 57.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
mapping(t, lambda x, p: (x ** 2, p))
%timeit mapping(t, lambda x, p: (x ** 2, p))
3.73 µs ± 23.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

pytree’s tree_map

from jax.tree_util import tree_map

tree_map(lambda x: x ** 2, _TREE_DATA_1)
{'a': 1, 'b': 4, 'x': {'c': 9, 'd': 16}}
%timeit tree_map(lambda x: x ** 2, _TREE_DATA_1)
6.56 µs ± 50.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

Flatten and Unflatten Operation

TreeValue’s Performance

from treevalue import flatten, flatten_keys, flatten_values

t_flatted = flatten(t)
[(('a',), 1), (('b',), 2), (('x', 'c'), 3), (('x', 'd'), 4)]
%timeit flatten(t)
885 ns ± 7.13 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
from treevalue import flatten_keys

[('a',), ('b',), ('x', 'c'), ('x', 'd')]
%timeit flatten_keys(t)
753 ns ± 18.7 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
from treevalue import flatten_values

[1, 2, 3, 4]
%timeit flatten_values(t)
553 ns ± 5.12 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
from treevalue import unflatten

%timeit unflatten(t_flatted)
982 ns ± 7.66 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

pytree’s Performance

from jax.tree_util import tree_flatten

leaves, treedef = tree_flatten(_TREE_DATA_1)
print('Leaves:', leaves)
print('Treedef:', treedef)
Leaves: [1, 2, 3, 4]
Treedef: PyTreeDef({'a': *, 'b': *, 'x': {'c': *, 'd': *}})
%timeit tree_flatten(_TREE_DATA_1)
2.25 µs ± 13.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
from jax.tree_util import tree_unflatten

tree_unflatten(treedef, leaves)
{'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}
%timeit tree_unflatten(treedef, leaves)
1.04 µs ± 6.45 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

All Operation

TreeValue’s Performance

%timeit all(flatten_values(t))
707 ns ± 6.23 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

pytree’s tree_all Performance

from jax.tree_util import tree_all
%timeit tree_all(_TREE_DATA_1)
2.53 µs ± 6.82 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

Reduce Operation

TreeValue’s Reduce

from functools import reduce

def _flatten_reduce(tree):
    """Fold all leaf values of ``tree`` together with ``+``, using no initial value.

    This is the TreeValue counterpart of ``tree_reduce(lambda x, y: x + y, tree)``
    from ``jax.tree_util``. The leaves are folded left-to-right in the order
    yielded by ``flatten_values``. If ``flatten_values`` yields no leaves,
    ``functools.reduce`` raises ``TypeError`` because no initial value is given.
    """
    # A lambda (not operator.add) is kept so the timing stays comparable with
    # the jax ``tree_reduce`` benchmark, which also uses a lambda.
    return reduce(lambda acc, leaf: acc + leaf, flatten_values(tree))

%timeit _flatten_reduce(t)
1.31 µs ± 28 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
def _flatten_reduce_with_init(tree):
    """Sum all leaf values of ``tree`` with ``+``, starting from ``0``.

    This is the TreeValue counterpart of
    ``tree_reduce(lambda x, y: x + y, tree, 0)`` from ``jax.tree_util``. Because
    an initial value is supplied, a tree without leaves yields ``0`` instead of
    raising.
    """
    leaves = flatten_values(tree)
    # 0 is passed positionally as reduce's ``initial`` argument.
    return reduce(lambda total, leaf: total + leaf, leaves, 0)

%timeit _flatten_reduce_with_init(t)
1.4 µs ± 23.8 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


from jax.tree_util import tree_reduce

tree_reduce(lambda x, y: x + y, _TREE_DATA_1)
%timeit tree_reduce(lambda x, y: x + y, _TREE_DATA_1)
3.29 µs ± 66.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
tree_reduce(lambda x, y: x + y, _TREE_DATA_1, 0)
%timeit tree_reduce(lambda x, y: x + y, _TREE_DATA_1, 0)
3.46 µs ± 60.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

Structure Transpose

Subside and Rise in TreeValue

from treevalue import subside

value = {
    'a': FastTreeValue({'a': 1, 'b': {'x': 2, 'y': 3}}),
    'b': FastTreeValue({'a': 10, 'b': {'x': 20, 'y': 30}}),
    'c': {
        'x': FastTreeValue({'a': 100, 'b': {'x': 200, 'y': 300}}),
        'y': FastTreeValue({'a': 400, 'b': {'x': 500, 'y': 600}}),
%timeit subside(value)
17.2 µs ± 253 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
from treevalue import raw, rise

value = FastTreeValue({
    'a': raw({'a': 1, 'b': {'x': 2, 'y': 3}}),
    'b': raw({'a': 10, 'b': {'x': 20, 'y': 30}}),
    'c': {
        'x': raw({'a': 100, 'b': {'x': 200, 'y': 300}}),
        'y': raw({'a': 400, 'b': {'x': 500, 'y': 600}}),
{'b': {'x': <FastTreeValue 0x7efe186b8490>
  ├── 'a' --> 2
  ├── 'b' --> 20
  └── 'c' --> <FastTreeValue 0x7efe186b8250>
      ├── 'x' --> 200
      └── 'y' --> 500,
  'y': <FastTreeValue 0x7efe1869a760>
  ├── 'a' --> 3
  ├── 'b' --> 30
  └── 'c' --> <FastTreeValue 0x7efe186b8130>
      ├── 'x' --> 300
      └── 'y' --> 600},
 'a': <FastTreeValue 0x7efe1869a8e0>
 ├── 'a' --> 1
 ├── 'b' --> 10
 └── 'c' --> <FastTreeValue 0x7efe1869ac10>
     ├── 'x' --> 100
     └── 'y' --> 400}
%timeit rise(value)
18.4 µs ± 187 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
vt = {'a': None, 'b': {'x': None, 'y': None}}
rise(value, template=vt)
{'a': <FastTreeValue 0x7efe186b0280>
 ├── 'a' --> 1
 ├── 'b' --> 10
 └── 'c' --> <FastTreeValue 0x7efe186b0c70>
     ├── 'x' --> 100
     └── 'y' --> 400,
 'b': {'x': <FastTreeValue 0x7efe186b0460>
  ├── 'a' --> 2
  ├── 'b' --> 20
  └── 'c' --> <FastTreeValue 0x7efe186b02e0>
      ├── 'x' --> 200
      └── 'y' --> 500,
  'y': <FastTreeValue 0x7efe186b06d0>
  ├── 'a' --> 3
  ├── 'b' --> 30
  └── 'c' --> <FastTreeValue 0x7efe186b0b80>
      ├── 'x' --> 300
      └── 'y' --> 600}}
%timeit rise(value, template=vt)
14.6 µs ± 129 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


from jax.tree_util import tree_structure, tree_transpose

sto = tree_structure({'a': 1, 'b': 2, 'c': {'x': 3, 'y': 4}})
sti = tree_structure({'a': 1, 'b': {'x': 2, 'y': 3}})

value = (
    {'a': 1, 'b': {'x': 2, 'y': 3}},
        'a': {'a': 10, 'b': {'x': 20, 'y': 30}},
        'b': [
            {'a': 100, 'b': {'x': 200, 'y': 300}},
            {'a': 400, 'b': {'x': 500, 'y': 600}},
tree_transpose(sto, sti, value)
{'a': {'a': 1, 'b': 10, 'c': {'x': 100, 'y': 400}},
 'b': {'x': {'a': 2, 'b': 20, 'c': {'x': 200, 'y': 500}},
  'y': {'a': 3, 'b': 30, 'c': {'x': 300, 'y': 600}}}}
%timeit tree_transpose(sto, sti, value)
16.3 µs ± 77.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
[ ]: