treevalue.tree.integration
register_for_jax
treevalue.tree.integration.register_for_jax(cls)
- Overview:
Register treevalue class for jax.
- Parameters:
cls – TreeValue class.
- Examples::
    >>> from treevalue import FastTreeValue, TreeValue, register_for_jax
    >>> register_for_jax(TreeValue)
    >>> register_for_jax(FastTreeValue)
Warning
If jax is not installed, this function emits a warning and does nothing.
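The warn-and-skip behaviour described in the warning can be pictured with a small standard-library sketch. This is not treevalue's implementation; `register_if_available` and the package name `surely_not_installed_pkg` are hypothetical names used only for illustration.

```python
import importlib.util
import warnings


def register_if_available(module_name, register):
    """Call ``register()`` only if ``module_name`` is importable.

    Hypothetical helper: returns True when registration ran, False otherwise.
    """
    if importlib.util.find_spec(module_name) is None:
        # Optional backend is missing: warn and do nothing else.
        warnings.warn(f"{module_name!r} is not installed, registration is skipped.")
        return False
    register()
    return True


calls = []
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # A made-up package name standing in for a backend that is not installed.
    ok = register_if_available("surely_not_installed_pkg", lambda: calls.append("done"))
```

With this pattern, code that registers a class can run unchanged on machines with or without the optional backend.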
register_for_torch
treevalue.tree.integration.register_for_torch(cls)
- Overview:
Register treevalue class for torch’s pytree library.
- Parameters:
cls – TreeValue class.
- Examples::
    >>> from treevalue import FastTreeValue, TreeValue, register_for_torch
    >>> register_for_torch(TreeValue)
    >>> register_for_torch(FastTreeValue)
Warning
If torch is not installed, this function emits a warning and does nothing.
register_treevalue_class
treevalue.tree.integration.register_treevalue_class(cls: Type[treevalue.tree.tree.tree.TreeValue], r_jax: bool = True, r_torch: bool = True)
- Overview:
Register a treevalue class with all supported integrations (jax and torch, each of which can be switched off).
- Parameters:
cls – TreeValue class.
r_jax – Register for jax, default is True.
r_torch – Register for torch, default is True.
register_integrate_container
treevalue.tree.integration.register_integrate_container(type_, flatten_func, unflatten_func)
- Overview:
Register custom data class for generic flatten and unflatten.
- Parameters:
type_ – Class of data to be registered.
flatten_func – Function for flattening.
unflatten_func – Function for unflattening.
- Examples::
    >>> from treevalue import register_integrate_container, generic_flatten, FastTreeValue, generic_unflatten
    >>>
    >>> class MyDC:
    ...     def __init__(self, x, y):
    ...         self.x = x
    ...         self.y = y
    ...
    ...     def __eq__(self, other):
    ...         return isinstance(other, MyDC) and self.x == other.x and self.y == other.y
    >>>
    >>> def _mydc_flatten(v):
    ...     return [v.x, v.y], MyDC
    >>>
    >>> def _mydc_unflatten(v, spec):  # spec will be MyDC
    ...     return spec(*v)
    >>>
    >>> register_integrate_container(MyDC, _mydc_flatten, _mydc_unflatten)  # register MyDC
    >>>
    >>> v, spec = generic_flatten({'a': MyDC(2, 3), 'b': MyDC((4, 5), FastTreeValue({'x': 1, 'y': 'f'}))})
    >>> v
    [[2, 3], [[4, 5], [1, 'f']]]
    >>>
    >>> rt = generic_unflatten(v, spec)
    >>> rt
    {'a': <__main__.MyDC object at 0x7fbda613f9d0>, 'b': <__main__.MyDC object at 0x7fbda6148150>}
    >>> rt['a'].x
    2
    >>> rt['a'].y
    3
    >>> rt['b'].x
    (4, 5)
    >>> rt['b'].y
    <FastTreeValue 0x7fbda5aed510>
    ├── 'x' --> 1
    └── 'y' --> 'f'
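To make the role of the two callbacks concrete, here is a minimal pure-Python sketch of a type registry that stores a flatten/unflatten pair per class. It is only an illustration of the idea and not how treevalue implements it; `_REGISTRY`, `register_container`, `flatten_one`, `unflatten_one` and `Point` are hypothetical names.

```python
# Hypothetical registry: maps a class to its (flatten, unflatten) pair.
_REGISTRY = {}


def register_container(type_, flatten_func, unflatten_func):
    """Remember how to take objects of ``type_`` apart and rebuild them."""
    _REGISTRY[type_] = (flatten_func, unflatten_func)


def flatten_one(obj):
    """Flatten one registered object into (children, (type, spec))."""
    flatten_func, _ = _REGISTRY[type(obj)]
    children, spec = flatten_func(obj)
    # Tag the spec with the type so the matching unflatten can be found later.
    return children, (type(obj), spec)


def unflatten_one(children, tagged_spec):
    """Rebuild an object from its children and the tagged spec."""
    type_, spec = tagged_spec
    _, unflatten_func = _REGISTRY[type_]
    return unflatten_func(children, spec)


class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y


# Same callback shapes as in the example above: flatten returns (values, spec),
# unflatten receives (values, spec).
register_container(Point, lambda p: ([p.x, p.y], Point), lambda v, spec: spec(*v))

values, spec = flatten_one(Point(2, 3))
restored = unflatten_one(values, spec)
```

The spec returned by the flatten callback is opaque to the registry; it is simply handed back to the unflatten callback, which is why the example above can use the class itself as the spec.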
generic_flatten
treevalue.tree.integration.generic_flatten(v)
- Overview:
Flatten generic data, including native objects, TreeValue, namedtuples and custom classes (see register_integrate_container()).
- Parameters:
v – Value to be flattened.
- Returns:
Flattened values, together with the spec needed to rebuild the original structure.
- Examples::
    >>> from collections import namedtuple
    >>> from easydict import EasyDict
    >>> from treevalue import FastTreeValue, generic_flatten, generic_unflatten
    >>>
    >>> class MyTreeValue(FastTreeValue):
    ...     pass
    >>>
    >>> nt = namedtuple('nt', ['a', 'b'])
    >>>
    >>> origin = {
    ...     'a': 1,
    ...     'b': (2, 3, 'f',),
    ...     'c': (2, 5, 'ds', EasyDict({  # dict's child class
    ...         'x': None,
    ...         'z': [34, '1.2'],  # dataclass
    ...     })),
    ...     'd': nt('f', 100),  # namedtuple
    ...     'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})  # treevalue
    ... }
    >>> v, spec = generic_flatten(origin)
    >>> v
    [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']]
    >>>
    >>> rv = generic_unflatten(v, spec)
    >>> rv
    {'a': 1, 'b': (2, 3, 'f'), 'c': (2, 5, 'ds', {'x': None, 'z': [34, '1.2']}), 'd': nt(a='f', b=100), 'e': <MyTreeValue 0x7fb6026d7b10>
    ├── 'x' --> 1
    └── 'y' --> 'dsfljk'
    }
    >>> type(rv['c'][-1])
    <class 'easydict.EasyDict'>
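The values/spec split shown above can be sketched in plain Python for the simplest case. The following is a rough illustration that assumes only plain dict, list and tuple containers (no namedtuples, TreeValue or registered classes); `flatten` and `unflatten` are hypothetical stand-ins, not the library functions.

```python
def flatten(obj):
    """Return (values, spec): nested lists of leaves plus a rebuild recipe."""
    if isinstance(obj, dict):
        keys = list(obj.keys())
        pairs = [flatten(obj[k]) for k in keys]
        return [v for v, _ in pairs], (type(obj), keys, [s for _, s in pairs])
    if isinstance(obj, (list, tuple)):
        pairs = [flatten(item) for item in obj]
        return [v for v, _ in pairs], (type(obj), None, [s for _, s in pairs])
    return obj, None  # leaf: the value itself, no spec needed


def unflatten(values, spec):
    """Inverse of ``flatten``: rebuild containers from values and spec."""
    if spec is None:
        return values
    type_, keys, child_specs = spec
    children = [unflatten(v, s) for v, s in zip(values, child_specs)]
    if keys is not None:
        return type_(zip(keys, children))  # dict-like: restore the keys
    return type_(children)  # list or tuple


origin = {'a': 1, 'b': (2, 3, 'f'), 'c': [4, {'x': None}]}
values, spec = flatten(origin)
restored = unflatten(values, spec)
```

Here `values` holds only the data, as nested lists, while `spec` records container types and dict keys, so the round trip restores tuples as tuples and dicts as dicts.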
generic_unflatten
treevalue.tree.integration.generic_unflatten(v, gspec)
- Overview:
Inverse operation of generic_flatten().
- Parameters:
v – Flattened values.
gspec – Spec data of original object.
- Examples::
See generic_flatten().
generic_mapping
treevalue.tree.integration.generic_mapping(v, func)
- Overview:
Generic map all the values, including native objects, TreeValue, namedtuples and custom classes (see register_integrate_container()).
- Parameters:
v – Original value, nested structure is supported.
func – Function applied to each value.
- Examples::
    >>> from collections import namedtuple
    >>> from easydict import EasyDict
    >>> from treevalue import FastTreeValue, generic_mapping
    >>>
    >>> class MyTreeValue(FastTreeValue):
    ...     pass
    >>>
    >>> nt = namedtuple('nt', ['a', 'b'])
    >>>
    >>> origin = {
    ...     'a': 1,
    ...     'b': (2, 3, 'f',),
    ...     'c': (2, 5, 'ds', EasyDict({  # dict's child class
    ...         'x': None,
    ...         'z': [34, '1.2'],  # dataclass
    ...     })),
    ...     'd': nt('f', 100),  # namedtuple
    ...     'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})  # treevalue
    ... }
    >>> generic_mapping(origin, str)
    {'a': '1', 'b': ('2', '3', 'f'), 'c': ('2', '5', 'ds', {'x': 'None', 'z': ['34', '1.2']}), 'd': nt(a='f', b='100'), 'e': <MyTreeValue 0x7f72e4d4ac90>
    ├── 'x' --> '1'
    └── 'y' --> 'dsfljk'
    }
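The effect of mapping over a nested structure can also be sketched with a short recursive function. This is an illustrative stand-in that assumes only plain dict, list and tuple containers; `map_leaves` is a hypothetical name and does not reflect how generic_mapping is implemented.

```python
def map_leaves(obj, func):
    """Apply ``func`` to every leaf while keeping the container structure."""
    if isinstance(obj, dict):
        # Rebuild the same dict type with mapped values and unchanged keys.
        return type(obj)((k, map_leaves(v, func)) for k, v in obj.items())
    if isinstance(obj, (list, tuple)):
        # Rebuild the same sequence type with mapped items.
        return type(obj)(map_leaves(item, func) for item in obj)
    return func(obj)  # leaf value


data = {'a': 1, 'b': (2, 3, 'f'), 'c': [None, [34, '1.2']]}
mapped = map_leaves(data, str)
```

As in the example above, container types and keys are preserved and only the leaf values change.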