treevalue.tree.integration

register_for_jax

treevalue.tree.integration.register_for_jax(cls)
Overview:

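Register a TreeValue class (or subclass) with jax as a pytree node, so that jax's pytree utilities (e.g. jax.tree_util.tree_map) can traverse its instances.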

Parameters:

cls – The TreeValue class (or subclass) to register.

Examples:
>>> from treevalue import FastTreeValue, TreeValue, register_for_jax
>>> register_for_jax(TreeValue)
>>> register_for_jax(FastTreeValue)
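Registering a class as a jax pytree node (via jax.tree_util.register_pytree_node) requires a flatten function that returns (children, aux_data) and an unflatten function that takes (aux_data, children) and rebuilds the object. The pure-Python sketch below shows such a pair for a hypothetical dict-backed TinyTree class standing in for a TreeValue. TinyTree, tiny_flatten and tiny_unflatten are illustrative assumptions, not treevalue's actual implementation.

```python
from typing import Any, Dict, List, Sequence, Tuple


class TinyTree:
    """Hypothetical stand-in for a TreeValue: wraps a dict of children."""

    def __init__(self, data: Dict[str, Any]):
        self.data = data


def tiny_flatten(tree: TinyTree) -> Tuple[List[Any], Tuple[str, ...]]:
    # Children are the values in sorted key order; the aux data records the
    # keys so the structure can be rebuilt. Once the class is registered,
    # jax would recurse into children that are themselves TinyTree nodes.
    keys = tuple(sorted(tree.data))
    return [tree.data[k] for k in keys], keys


def tiny_unflatten(keys: Tuple[str, ...], children: Sequence[Any]) -> TinyTree:
    # Inverse of tiny_flatten: pair each key with its child again.
    return TinyTree(dict(zip(keys, children)))
```

With jax installed, the pair could be passed as jax.tree_util.register_pytree_node(TinyTree, tiny_flatten, tiny_unflatten); register_for_jax performs the equivalent step for TreeValue classes.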

Warning

If jax is not installed, this function emits a warning and does nothing.
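The warn-and-skip behaviour described in the warning can be sketched with an optional import. The helper register_if_jax_available below is hypothetical and only mirrors the documented behaviour; it is not part of treevalue. It returns True when registration happened and False when jax could not be imported.

```python
import warnings
from typing import Callable


def register_if_jax_available(cls: type, flatten: Callable, unflatten: Callable) -> bool:
    """Hypothetical helper: register cls as a jax pytree node if jax is importable."""
    try:
        from jax.tree_util import register_pytree_node
    except ImportError:
        # jax is missing: emit a warning and leave cls untouched.
        warnings.warn("jax is not installed; pytree registration skipped.")
        return False
    # Note: jax rejects registering the same class twice.
    register_pytree_node(cls, flatten, unflatten)
    return True
```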