treevalue.tree.integration

register_for_jax

treevalue.tree.integration.register_for_jax(cls)
Overview:

Register a TreeValue class for use with jax.

Parameters:

cls – TreeValue class.

Examples::
>>> from treevalue import FastTreeValue, TreeValue, register_for_jax
>>> register_for_jax(TreeValue)
>>> register_for_jax(FastTreeValue)

Warning

When jax is not installed, this function emits a warning and does nothing.
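The warn-and-skip behaviour described above can be illustrated with a minimal sketch. It is not treevalue's actual implementation: `register_if_available` and the module name `no_such_backend_module` are hypothetical names used only for illustration, and only standard-library calls are used.

```python
import importlib.util
import warnings


def register_if_available(module_name, do_register):
    """Hypothetical helper: run do_register() only if module_name is importable.

    Returns True when registration ran, False when it was skipped.
    """
    if importlib.util.find_spec(module_name) is None:
        # The optional backend is missing: warn and do nothing else.
        warnings.warn(
            f"{module_name!r} is not installed, registration skipped.",
            stacklevel=2,
        )
        return False
    do_register()
    return True


calls = []

# A backend that does not exist: a warning is emitted, nothing is registered.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    skipped = register_if_available(
        "no_such_backend_module", lambda: calls.append("missing")
    )

# A module that does exist (json is in the standard library): registration runs.
done = register_if_available("json", lambda: calls.append("json"))
```

With this pattern a caller can register unconditionally at import time without making the optional backend a hard dependency.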

register_for_torch

treevalue.tree.integration.register_for_torch(cls)
Overview:

Register a TreeValue class with torch’s pytree library.

Parameters:

cls – TreeValue class.

Examples::
>>> from treevalue import FastTreeValue, TreeValue, register_for_torch
>>> register_for_torch(TreeValue)
>>> register_for_torch(FastTreeValue)

Warning

When torch is not installed, this function emits a warning and does nothing.

register_treevalue_class

treevalue.tree.integration.register_treevalue_class(cls: Type[treevalue.tree.tree.tree.TreeValue], r_jax: bool = True, r_torch: bool = True)
Overview:

Register a TreeValue class with all supported integrations (jax and torch); each one can be switched off with the corresponding flag.

Parameters:
  • cls – TreeValue class.

  • r_jax – Register for jax, default is True.

  • r_torch – Register for torch, default is True.

register_integrate_container

treevalue.tree.integration.register_integrate_container(type_, flatten_func, unflatten_func)
Overview:

Register custom data class for generic flatten and unflatten.

Parameters:
  • type_ – Class of data to be registered.

  • flatten_func – Function for flattening.

  • unflatten_func – Function for unflattening.

Examples::
>>> from treevalue import register_integrate_container, generic_flatten, FastTreeValue, generic_unflatten
>>> 
>>> class MyDC:
...     def __init__(self, x, y):
...         self.x = x
...         self.y = y
...
...     def __eq__(self, other):
...         return isinstance(other, MyDC) and self.x == other.x and self.y == other.y
>>> 
>>> def _mydc_flatten(v):
...     return [v.x, v.y], MyDC
>>> 
>>> def _mydc_unflatten(v, spec):  # spec will be MyDC
...     return spec(*v)
>>> 
>>> register_integrate_container(MyDC, _mydc_flatten, _mydc_unflatten)  # register MyDC
>>> 
>>> v, spec = generic_flatten({'a': MyDC(2, 3), 'b': MyDC((4, 5), FastTreeValue({'x': 1, 'y': 'f'}))})
>>> v
[[2, 3], [[4, 5], [1, 'f']]]
>>> 
>>> rt = generic_unflatten(v, spec)
>>> rt
{'a': <__main__.MyDC object at 0x7fbda613f9d0>, 'b': <__main__.MyDC object at 0x7fbda6148150>}
>>> rt['a'].x
2
>>> rt['a'].y
3
>>> rt['b'].x
(4, 5)
>>> rt['b'].y
<FastTreeValue 0x7fbda5aed510>
├── 'x' --> 1
└── 'y' --> 'f'
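To make the contract between flatten_func and unflatten_func more concrete, here is a minimal, self-contained sketch of a flatten/unflatten registry. It is an assumption-laden illustration, not treevalue's implementation: `register`, `flatten`, `unflatten`, `_REGISTRY` and `Point` are hypothetical names, and only dicts, lists, tuples and registered classes are handled.

```python
# Hypothetical registry: type -> (flatten_func, unflatten_func).
_REGISTRY = {}


def register(type_, flatten_func, unflatten_func):
    """Remember how to take instances of type_ apart and rebuild them."""
    _REGISTRY[type_] = (flatten_func, unflatten_func)


def flatten(v):
    """Return (values, spec); spec records how to rebuild v from values."""
    for type_, (flatten_func, _) in _REGISTRY.items():
        if isinstance(v, type_):
            items, inner = flatten_func(v)  # user-supplied (children, spec)
            subs = [flatten(i) for i in items]
            return [s[0] for s in subs], ("custom", type_, inner, [s[1] for s in subs])
    if isinstance(v, dict):
        subs = [flatten(x) for x in v.values()]
        return [s[0] for s in subs], ("dict", list(v.keys()), None, [s[1] for s in subs])
    if isinstance(v, (list, tuple)):
        subs = [flatten(x) for x in v]
        return [s[0] for s in subs], ("seq", type(v), None, [s[1] for s in subs])
    return v, ("leaf", None, None, None)  # anything else is a leaf


def unflatten(values, spec):
    """Inverse of flatten()."""
    kind, info, inner, subspecs = spec
    if kind == "leaf":
        return values
    children = [unflatten(x, s) for x, s in zip(values, subspecs)]
    if kind == "custom":
        return _REGISTRY[info][1](children, inner)  # user-supplied rebuild
    if kind == "dict":
        return dict(zip(info, children))
    return info(children)  # list or tuple


class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y


# Same shape of callbacks as in the example above: flatten returns
# (children, spec), unflatten receives (children, spec).
register(Point, lambda p: ([p.x, p.y], Point), lambda vals, spec: spec(*vals))

values, spec = flatten({"a": Point(2, 3), "b": [Point((4, 5), 6)]})
restored = unflatten(values, spec)
```

Here `values` is `[[2, 3], [[[4, 5], 6]]]`: the nesting mirrors the original structure, while the container types live only in `spec`.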

generic_flatten

treevalue.tree.integration.generic_flatten(v)
Overview:

Flatten generic data, including native objects, TreeValue, namedtuples and custom classes (see register_integrate_container()).

Parameters:

v – Value to be flattened.

Returns:

Tuple of the flattened values and the spec needed to restore the original object with generic_unflatten().

Examples::
>>> from collections import namedtuple
>>> from easydict import EasyDict
>>> from treevalue import FastTreeValue, generic_flatten, generic_unflatten
>>> 
>>> class MyTreeValue(FastTreeValue):
...     pass
>>> 
>>> nt = namedtuple('nt', ['a', 'b'])
>>> 
>>> origin = {
...     'a': 1,
...     'b': (2, 3, 'f',),
...     'c': (2, 5, 'ds', EasyDict({  # dict's child class
...         'x': None,
...         'z': [34, '1.2'],
...     })),
...     'd': nt('f', 100),  # namedtuple
...     'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})  # treevalue
... }
>>> v, spec = generic_flatten(origin)
>>> v
[1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']]
>>> 
>>> rv = generic_unflatten(v, spec)
>>> rv
{'a': 1, 'b': (2, 3, 'f'), 'c': (2, 5, 'ds', {'x': None, 'z': [34, '1.2']}), 'd': nt(a='f', b=100), 'e': <MyTreeValue 0x7fb6026d7b10>
├── 'x' --> 1
└── 'y' --> 'dsfljk'
}
>>> type(rv['c'][-1])
<class 'easydict.EasyDict'>

generic_unflatten

treevalue.tree.integration.generic_unflatten(v, gspec)
Overview:

Inverse operation of generic_flatten().

Parameters:
  • v – Flattened values.

  • gspec – Spec data of the original object.

Examples::

See generic_flatten().

generic_mapping

treevalue.tree.integration.generic_mapping(v, func)
Overview:

Apply a function to every value in generic data, including native objects, TreeValue, namedtuples and custom classes (see register_integrate_container()).

Parameters:
  • v – Original value, nested structure is supported.

  • func – Function applied to each value.

Examples::
>>> from collections import namedtuple
>>> from easydict import EasyDict
>>> from treevalue import FastTreeValue, generic_mapping
>>> 
>>> class MyTreeValue(FastTreeValue):
...     pass
>>> 
>>> nt = namedtuple('nt', ['a', 'b'])
>>> 
>>> origin = {
...     'a': 1,
...     'b': (2, 3, 'f',),
...     'c': (2, 5, 'ds', EasyDict({  # dict's child class
...         'x': None,
...         'z': [34, '1.2'],
...     })),
...     'd': nt('f', 100),  # namedtuple
...     'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})  # treevalue
... }
>>> generic_mapping(origin, str)
{'a': '1', 'b': ('2', '3', 'f'), 'c': ('2', '5', 'ds', {'x': 'None', 'z': ['34', '1.2']}), 'd': nt(a='f', b='100'), 'e': <MyTreeValue 0x7f72e4d4ac90>
├── 'x' --> '1'
└── 'y' --> 'dsfljk'
}
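The structure-preserving behaviour shown above can be approximated in plain Python. The following is a rough sketch under stated assumptions, not treevalue's implementation: `map_leaves` is a hypothetical helper, and it covers only dicts (including subclasses), lists, tuples and namedtuples, without TreeValue or registered custom classes.

```python
from collections import namedtuple


def map_leaves(v, func):
    """Hypothetical helper: apply func to every leaf, keeping container types."""
    if isinstance(v, dict):
        # type(v) keeps dict subclasses intact.
        return type(v)({k: map_leaves(x, func) for k, x in v.items()})
    if isinstance(v, tuple) and hasattr(v, "_fields"):
        # namedtuple: rebuild with positional arguments.
        return type(v)(*(map_leaves(x, func) for x in v))
    if isinstance(v, (list, tuple)):
        return type(v)(map_leaves(x, func) for x in v)
    return func(v)  # leaf value


nt = namedtuple("nt", ["a", "b"])
origin = {"a": 1, "b": (2, 3, "f"), "c": [None, nt("f", 100)]}
mapped = map_leaves(origin, str)
```

As with generic_mapping, the result has the same shape and container types as the input; only the leaf values change.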