{ "cells": [ { "cell_type": "markdown", "id": "1a1cab6c", "metadata": {}, "source": [ "# Comparison Between TreeValue and Jax LibTree" ] }, { "cell_type": "markdown", "id": "a9426f1e", "metadata": {}, "source": [ "In this section, we will look at the features and performance of the [jax-libtree](https://jax.readthedocs.io/en/latest/pytrees.html) library (JAX's pytree utilities in `jax.tree_util`), which is developed by Google." ] }, { "cell_type": "code", "execution_count": 1, "id": "0f4b6d16", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:05.433546Z", "iopub.status.busy": "2024-10-16T13:27:05.432987Z", "iopub.status.idle": "2024-10-16T13:27:05.440334Z", "shell.execute_reply": "2024-10-16T13:27:05.439795Z" } }, "outputs": [], "source": [ "_TREE_DATA_1 = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}" ] }, { "cell_type": "markdown", "id": "b1085059", "metadata": {}, "source": [ "## Mapping Operation" ] }, { "cell_type": "markdown", "id": "4ded33f7", "metadata": {}, "source": [ "### TreeValue's Mapping" ] }, { "cell_type": "code", "execution_count": 2, "id": "aa9abbed", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:05.442649Z", "iopub.status.busy": "2024-10-16T13:27:05.442279Z", "iopub.status.idle": "2024-10-16T13:27:06.965034Z", "shell.execute_reply": "2024-10-16T13:27:06.964316Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/tmp4q2ykbbk/fbb7ad1854f8455286e2cedd40b3312212c7f17a/treevalue/tree/integration/torch.py:18: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n", " register_for_torch(TreeValue)\n", "/tmp/tmp4q2ykbbk/fbb7ad1854f8455286e2cedd40b3312212c7f17a/treevalue/tree/integration/torch.py:19: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. 
Please use `torch.utils._pytree.register_pytree_node` instead.\n", " register_for_torch(FastTreeValue)\n" ] }, { "data": { "text/plain": [ "\n", "├── 'a' --> 1\n", "├── 'b' --> 4\n", "└── 'x' --> \n", " ├── 'c' --> 9\n", " └── 'd' --> 16" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import mapping, FastTreeValue\n", "\n", "t = FastTreeValue(_TREE_DATA_1)\n", "mapping(t, lambda x: x ** 2)" ] }, { "cell_type": "code", "execution_count": 3, "id": "4dc3420b", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:06.967807Z", "iopub.status.busy": "2024-10-16T13:27:06.967365Z", "iopub.status.idle": "2024-10-16T13:27:09.095618Z", "shell.execute_reply": "2024-10-16T13:27:09.094932Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.62 µs ± 15.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit mapping(t, lambda x: x ** 2)" ] }, { "cell_type": "code", "execution_count": 4, "id": "289e14c6", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:09.097820Z", "iopub.status.busy": "2024-10-16T13:27:09.097412Z", "iopub.status.idle": "2024-10-16T13:27:09.102223Z", "shell.execute_reply": "2024-10-16T13:27:09.101693Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> (1, ('a',))\n", "├── 'b' --> (4, ('b',))\n", "└── 'x' --> \n", " ├── 'c' --> (9, ('x', 'c'))\n", " └── 'd' --> (16, ('x', 'd'))" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mapping(t, lambda x, p: (x ** 2, p))" ] }, { "cell_type": "code", "execution_count": 5, "id": "fce10189", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:09.104353Z", "iopub.status.busy": "2024-10-16T13:27:09.103975Z", "iopub.status.idle": "2024-10-16T13:27:11.295458Z", "shell.execute_reply": "2024-10-16T13:27:11.294803Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.7 µs ± 
3.91 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit mapping(t, lambda x, p: (x ** 2, p))" ] }, { "cell_type": "markdown", "id": "18d4346d", "metadata": {}, "source": [ "### pytree's tree_map" ] }, { "cell_type": "code", "execution_count": 6, "id": "91b8e706", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:11.297556Z", "iopub.status.busy": "2024-10-16T13:27:11.297358Z", "iopub.status.idle": "2024-10-16T13:27:11.302183Z", "shell.execute_reply": "2024-10-16T13:27:11.301619Z" } }, "outputs": [ { "data": { "text/plain": [ "{'a': 1, 'b': 4, 'x': {'c': 9, 'd': 16}}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from jax.tree_util import tree_map\n", "\n", "tree_map(lambda x: x ** 2, _TREE_DATA_1)" ] }, { "cell_type": "code", "execution_count": 7, "id": "41fa094b", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:11.304269Z", "iopub.status.busy": "2024-10-16T13:27:11.303892Z", "iopub.status.idle": "2024-10-16T13:27:15.111198Z", "shell.execute_reply": "2024-10-16T13:27:15.110548Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4.66 µs ± 16.8 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit tree_map(lambda x: x ** 2, _TREE_DATA_1)" ] }, { "cell_type": "markdown", "id": "36191201", "metadata": {}, "source": [ "## Flatten and Unflatten Operation" ] }, { "cell_type": "markdown", "id": "a733f250", "metadata": {}, "source": [ "### TreeValue's Performance" ] }, { "cell_type": "code", "execution_count": 8, "id": "676f03c0", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:15.113581Z", "iopub.status.busy": "2024-10-16T13:27:15.113154Z", "iopub.status.idle": "2024-10-16T13:27:15.118097Z", "shell.execute_reply": "2024-10-16T13:27:15.117521Z" } }, "outputs": [ { "data": { "text/plain": [ "[(('a',), 1), (('b',), 2), (('x', 'c'), 3), (('x', 'd'), 4)]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import flatten, flatten_keys, flatten_values\n", "\n", "t_flatted = flatten(t)\n", "t_flatted" ] }, { "cell_type": "code", "execution_count": 9, "id": "7ca0363d", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:15.119919Z", "iopub.status.busy": "2024-10-16T13:27:15.119728Z", "iopub.status.idle": "2024-10-16T13:27:19.241996Z", "shell.execute_reply": "2024-10-16T13:27:19.241206Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "508 ns ± 5.59 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit flatten(t)" ] }, { "cell_type": "code", "execution_count": 10, "id": "62d3bd2d", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:19.244280Z", "iopub.status.busy": "2024-10-16T13:27:19.243871Z", "iopub.status.idle": "2024-10-16T13:27:19.248527Z", "shell.execute_reply": "2024-10-16T13:27:19.247854Z" } }, "outputs": [ { "data": { "text/plain": [ "[('a',), ('b',), ('x', 'c'), ('x', 'd')]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import flatten_keys\n", "\n", "flatten_keys(t)" ] }, { "cell_type": "code", "execution_count": 11, "id": "58fbf8c4", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:19.250694Z", "iopub.status.busy": "2024-10-16T13:27:19.250313Z", "iopub.status.idle": "2024-10-16T13:27:22.956562Z", "shell.execute_reply": "2024-10-16T13:27:22.955910Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "457 ns ± 0.752 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit flatten_keys(t)" ] }, { "cell_type": "code", "execution_count": 12, "id": "1448e341", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:22.958726Z", "iopub.status.busy": "2024-10-16T13:27:22.958504Z", "iopub.status.idle": "2024-10-16T13:27:22.962731Z", "shell.execute_reply": "2024-10-16T13:27:22.962207Z" } }, "outputs": [ { "data": { "text/plain": [ "[1, 2, 3, 4]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import flatten_values\n", "\n", "flatten_values(t)" ] }, { "cell_type": "code", "execution_count": 13, "id": "956cf1c1", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:22.964664Z", "iopub.status.busy": "2024-10-16T13:27:22.964464Z", "iopub.status.idle": "2024-10-16T13:27:25.579519Z", "shell.execute_reply": "2024-10-16T13:27:25.578825Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "322 ns ± 2.83 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit flatten_values(t)" ] }, { "cell_type": "code", "execution_count": 14, "id": "f35e3a6d", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:25.581844Z", "iopub.status.busy": "2024-10-16T13:27:25.581336Z", "iopub.status.idle": "2024-10-16T13:27:25.586094Z", "shell.execute_reply": "2024-10-16T13:27:25.585406Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> 1\n", "├── 'b' --> 2\n", "└── 'x' --> \n", " ├── 'c' --> 3\n", " └── 'd' --> 4" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import unflatten\n", "\n", "unflatten(t_flatted)" ] }, { "cell_type": "code", "execution_count": 15, "id": "52980ab8", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:25.588238Z", "iopub.status.busy": "2024-10-16T13:27:25.587841Z", "iopub.status.idle": "2024-10-16T13:27:30.341189Z", "shell.execute_reply": "2024-10-16T13:27:30.340465Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "586 ns ± 4.45 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit unflatten(t_flatted)" ] }, { "cell_type": "markdown", "id": "0ff80a6c", "metadata": {}, "source": [ "### pytree's Performance" ] }, { "cell_type": "code", "execution_count": 16, "id": "e89fc565", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:30.343745Z", "iopub.status.busy": "2024-10-16T13:27:30.343308Z", "iopub.status.idle": "2024-10-16T13:27:30.346972Z", "shell.execute_reply": "2024-10-16T13:27:30.346325Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Leaves: [1, 2, 3, 4]\n", "Treedef: PyTreeDef({'a': *, 'b': *, 'x': {'c': *, 'd': *}})\n" ] } ], "source": [ "from jax.tree_util import tree_flatten\n", "\n", "leaves, treedef = tree_flatten(_TREE_DATA_1)\n", "print('Leaves:', leaves)\n", "print('Treedef:', treedef)" ] }, { "cell_type": "code", "execution_count": 17, "id": "5d59c4d3", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:30.349154Z", "iopub.status.busy": "2024-10-16T13:27:30.348773Z", "iopub.status.idle": "2024-10-16T13:27:41.732819Z", "shell.execute_reply": "2024-10-16T13:27:41.732184Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.4 µs ± 7.94 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit tree_flatten(_TREE_DATA_1)" ] }, { "cell_type": "code", "execution_count": 18, "id": "bcb1318c", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:41.735170Z", "iopub.status.busy": "2024-10-16T13:27:41.734720Z", "iopub.status.idle": "2024-10-16T13:27:41.739403Z", "shell.execute_reply": "2024-10-16T13:27:41.738701Z" } }, "outputs": [ { "data": { "text/plain": [ "{'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from jax.tree_util import tree_unflatten\n", "\n", "tree_unflatten(treedef, leaves)" ] }, { "cell_type": "code", "execution_count": 19, "id": "638b5144", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:41.741471Z", "iopub.status.busy": "2024-10-16T13:27:41.741110Z", "iopub.status.idle": "2024-10-16T13:27:47.350292Z", "shell.execute_reply": "2024-10-16T13:27:47.349617Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "692 ns ± 3.41 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit tree_unflatten(treedef, leaves)" ] }, { "cell_type": "markdown", "id": "ffd1522b", "metadata": {}, "source": [ "## All Operation" ] }, { "cell_type": "markdown", "id": "112c91a1", "metadata": {}, "source": [ "### TreeValue's Performance" ] }, { "cell_type": "code", "execution_count": 20, "id": "1b9c1c51", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:47.352531Z", "iopub.status.busy": "2024-10-16T13:27:47.352280Z", "iopub.status.idle": "2024-10-16T13:27:47.356913Z", "shell.execute_reply": "2024-10-16T13:27:47.356239Z" } }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "all(flatten_values(t))" ] }, { "cell_type": "code", "execution_count": 21, "id": "fb8b0b2d", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:47.358993Z", "iopub.status.busy": "2024-10-16T13:27:47.358639Z", "iopub.status.idle": "2024-10-16T13:27:50.831746Z", "shell.execute_reply": "2024-10-16T13:27:50.831086Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "428 ns ± 0.568 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit all(flatten_values(t))" ] }, { "cell_type": "markdown", "id": "86462abd", "metadata": {}, "source": [ "### pytree.tree_all's Performance" ] }, { "cell_type": "code", "execution_count": 22, "id": "64d39484", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:50.834032Z", "iopub.status.busy": "2024-10-16T13:27:50.833650Z", "iopub.status.idle": "2024-10-16T13:27:50.836650Z", "shell.execute_reply": "2024-10-16T13:27:50.836125Z" } }, "outputs": [], "source": [ "from jax.tree_util import tree_all" ] }, { "cell_type": "code", "execution_count": 23, "id": "6a8427db", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:50.838551Z", "iopub.status.busy": "2024-10-16T13:27:50.838352Z", "iopub.status.idle": "2024-10-16T13:27:50.842375Z", "shell.execute_reply": "2024-10-16T13:27:50.841828Z" } }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree_all(_TREE_DATA_1)" ] }, { "cell_type": "code", "execution_count": 24, "id": "e4a4cd59", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:27:50.844192Z", "iopub.status.busy": "2024-10-16T13:27:50.843975Z", "iopub.status.idle": "2024-10-16T13:28:03.573173Z", "shell.execute_reply": "2024-10-16T13:28:03.572508Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.57 µs ± 6.85 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit tree_all(_TREE_DATA_1)" ] }, { "cell_type": "markdown", "id": "f0221182", "metadata": {}, "source": [ "## Reduce Operation" ] }, { "cell_type": "markdown", "id": "646b4a2e", "metadata": {}, "source": [ "### TreeValue's Reduce" ] }, { "cell_type": "code", "execution_count": 25, "id": "105e9001", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:28:03.575641Z", "iopub.status.busy": "2024-10-16T13:28:03.575243Z", "iopub.status.idle": "2024-10-16T13:28:03.580251Z", "shell.execute_reply": "2024-10-16T13:28:03.579717Z" } }, "outputs": [ { "data": { "text/plain": [ "10" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from functools import reduce\n", "\n", "def _flatten_reduce(tree):\n", " values = flatten_values(tree)\n", " return reduce(lambda x, y: x + y, values)\n", "\n", "_flatten_reduce(t)" ] }, { "cell_type": "code", "execution_count": 26, "id": "71c145c0", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:28:03.582275Z", "iopub.status.busy": "2024-10-16T13:28:03.581942Z", "iopub.status.idle": "2024-10-16T13:28:10.171959Z", "shell.execute_reply": "2024-10-16T13:28:10.171224Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "814 ns ± 7.81 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit _flatten_reduce(t)" ] }, { "cell_type": "code", "execution_count": 27, "id": "20b34964", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:28:10.174092Z", "iopub.status.busy": "2024-10-16T13:28:10.173843Z", "iopub.status.idle": "2024-10-16T13:28:10.178947Z", "shell.execute_reply": "2024-10-16T13:28:10.178364Z" } }, "outputs": [ { "data": { "text/plain": [ "10" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def _flatten_reduce_with_init(tree):\n", " values = flatten_values(tree)\n", " return reduce(lambda x, y: x + y, values, 0)\n", "\n", "_flatten_reduce_with_init(t)" ] }, { "cell_type": "code", "execution_count": 28, "id": "ca8562c5", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:28:10.180896Z", "iopub.status.busy": "2024-10-16T13:28:10.180535Z", "iopub.status.idle": "2024-10-16T13:28:17.191774Z", "shell.execute_reply": "2024-10-16T13:28:17.191026Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "864 ns ± 1.24 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit _flatten_reduce_with_init(t)" ] }, { "cell_type": "markdown", "id": "12b2a28a", "metadata": {}, "source": [ "### pytree.tree_reduce" ] }, { "cell_type": "code", "execution_count": 29, "id": "14e9237e", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:28:17.193914Z", "iopub.status.busy": "2024-10-16T13:28:17.193695Z", "iopub.status.idle": "2024-10-16T13:28:17.198269Z", "shell.execute_reply": "2024-10-16T13:28:17.197708Z" } }, "outputs": [ { "data": { "text/plain": [ "10" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from jax.tree_util import tree_reduce\n", "\n", "tree_reduce(lambda x, y: x + y, _TREE_DATA_1)" ] }, { "cell_type": "code", "execution_count": 30, "id": "f6c5cfda", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:28:17.200419Z", "iopub.status.busy": "2024-10-16T13:28:17.200043Z", "iopub.status.idle": "2024-10-16T13:28:18.879660Z", "shell.execute_reply": "2024-10-16T13:28:18.878937Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.07 µs ± 43.3 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit tree_reduce(lambda x, y: x + y, _TREE_DATA_1)" ] }, { "cell_type": "code", "execution_count": 31, "id": "e79fb73b", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:28:18.881853Z", "iopub.status.busy": "2024-10-16T13:28:18.881643Z", "iopub.status.idle": "2024-10-16T13:28:18.886352Z", "shell.execute_reply": "2024-10-16T13:28:18.885740Z" } }, "outputs": [ { "data": { "text/plain": [ "10" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree_reduce(lambda x, y: x + y, _TREE_DATA_1, 0)" ] }, { "cell_type": "code", "execution_count": 32, "id": "af47dd74", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:28:18.888601Z", "iopub.status.busy": "2024-10-16T13:28:18.888139Z", "iopub.status.idle": "2024-10-16T13:28:20.618580Z", "shell.execute_reply": "2024-10-16T13:28:20.617924Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.13 µs ± 14.6 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit tree_reduce(lambda x, y: x + y, _TREE_DATA_1, 0)" ] }, { "cell_type": "markdown", "id": "6e58b70f", "metadata": {}, "source": [ "## Structure Transpose" ] }, { "cell_type": "markdown", "id": "a4a31f9e", "metadata": {}, "source": [ "### Subside and Rise in TreeValue" ] }, { "cell_type": "code", "execution_count": 33, "id": "b6af689f", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:28:20.621026Z", "iopub.status.busy": "2024-10-16T13:28:20.620518Z", "iopub.status.idle": "2024-10-16T13:28:20.627002Z", "shell.execute_reply": "2024-10-16T13:28:20.626426Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> {'a': 1, 'b': 10, 'c': {'x': 100, 'y': 400}}\n", "└── 'b' --> \n", " ├── 'x' --> {'a': 2, 'b': 20, 'c': {'x': 200, 'y': 500}}\n", " └── 'y' --> {'a': 3, 'b': 30, 'c': {'x': 300, 'y': 600}}" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import subside\n", "\n", "value = {\n", " 'a': FastTreeValue({'a': 1, 'b': {'x': 2, 'y': 3}}),\n", " 'b': FastTreeValue({'a': 10, 'b': {'x': 20, 'y': 30}}),\n", " 'c': {\n", " 'x': FastTreeValue({'a': 100, 'b': {'x': 200, 'y': 300}}),\n", " 'y': FastTreeValue({'a': 400, 'b': {'x': 500, 'y': 600}}),\n", " },\n", "}\n", "subside(value)" ] }, { "cell_type": "code", "execution_count": 34, "id": "b85caa84", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:28:20.628839Z", "iopub.status.busy": "2024-10-16T13:28:20.628639Z", "iopub.status.idle": "2024-10-16T13:28:29.015122Z", "shell.execute_reply": "2024-10-16T13:28:29.014382Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10.3 µs ± 54.6 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit subside(value)" ] }, { "cell_type": "code", "execution_count": 35, "id": "f304d286", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:28:29.017270Z", "iopub.status.busy": "2024-10-16T13:28:29.017048Z", "iopub.status.idle": "2024-10-16T13:28:29.023577Z", "shell.execute_reply": "2024-10-16T13:28:29.022918Z" } }, "outputs": [ { "data": { "text/plain": [ "{'a': \n", " ├── 'a' --> 1\n", " ├── 'b' --> 10\n", " └── 'c' --> \n", " ├── 'x' --> 100\n", " └── 'y' --> 400,\n", " 'b': {'y': \n", " ├── 'a' --> 3\n", " ├── 'b' --> 30\n", " └── 'c' --> \n", " ├── 'x' --> 300\n", " └── 'y' --> 600,\n", " 'x': \n", " ├── 'a' --> 2\n", " ├── 'b' --> 20\n", " └── 'c' --> \n", " ├── 'x' --> 200\n", " └── 'y' --> 500}}" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import raw, rise\n", "\n", "value = FastTreeValue({\n", " 'a': raw({'a': 1, 'b': {'x': 2, 'y': 3}}),\n", " 'b': raw({'a': 10, 'b': {'x': 20, 'y': 30}}),\n", " 'c': {\n", " 'x': raw({'a': 100, 'b': {'x': 200, 'y': 300}}),\n", " 'y': raw({'a': 400, 'b': {'x': 500, 'y': 600}}),\n", " },\n", "})\n", "rise(value)" ] }, { "cell_type": "code", "execution_count": 36, "id": "62d34321", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:28:29.025558Z", "iopub.status.busy": "2024-10-16T13:28:29.025198Z", "iopub.status.idle": "2024-10-16T13:28:38.150902Z", "shell.execute_reply": "2024-10-16T13:28:38.150173Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "11.3 µs ± 28.9 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit rise(value)" ] }, { "cell_type": "code", "execution_count": 37, "id": "f0ed4793", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:28:38.153139Z", "iopub.status.busy": "2024-10-16T13:28:38.152754Z", "iopub.status.idle": "2024-10-16T13:28:38.157719Z", "shell.execute_reply": "2024-10-16T13:28:38.157043Z" } }, "outputs": [ { "data": { "text/plain": [ "{'a': \n", " ├── 'a' --> 1\n", " ├── 'b' --> 10\n", " └── 'c' --> \n", " ├── 'x' --> 100\n", " └── 'y' --> 400,\n", " 'b': {'x': \n", " ├── 'a' --> 2\n", " ├── 'b' --> 20\n", " └── 'c' --> \n", " ├── 'x' --> 200\n", " └── 'y' --> 500,\n", " 'y': \n", " ├── 'a' --> 3\n", " ├── 'b' --> 30\n", " └── 'c' --> \n", " ├── 'x' --> 300\n", " └── 'y' --> 600}}" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vt = {'a': None, 'b': {'x': None, 'y': None}}\n", "rise(value, template=vt)" ] }, { "cell_type": "code", "execution_count": 38, "id": "a6ad3639", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:28:38.159841Z", "iopub.status.busy": "2024-10-16T13:28:38.159382Z", "iopub.status.idle": "2024-10-16T13:28:45.508546Z", "shell.execute_reply": "2024-10-16T13:28:45.507802Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "9.03 µs ± 65.9 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit rise(value, template=vt)" ] }, { "cell_type": "markdown", "id": "19ee26c6", "metadata": {}, "source": [ "### pytree.tree_transpose" ] }, { "cell_type": "code", "execution_count": 39, "id": "7f24f0f6", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:28:45.510956Z", "iopub.status.busy": "2024-10-16T13:28:45.510477Z", "iopub.status.idle": "2024-10-16T13:28:45.517507Z", "shell.execute_reply": "2024-10-16T13:28:45.516864Z" } }, "outputs": [ { "data": { "text/plain": [ "{'a': {'a': 1, 'b': 10, 'c': {'x': 100, 'y': 400}},\n", " 'b': {'x': {'a': 2, 'b': 20, 'c': {'x': 200, 'y': 500}},\n", " 'y': {'a': 3, 'b': 30, 'c': {'x': 300, 'y': 600}}}}" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from jax.tree_util import tree_structure, tree_transpose\n", "\n", "sto = tree_structure({'a': 1, 'b': 2, 'c': {'x': 3, 'y': 4}})\n", "sti = tree_structure({'a': 1, 'b': {'x': 2, 'y': 3}})\n", "\n", "# tree_transpose expects an input whose structure is `sto` composed with `sti`\n", "# (the same layout as the `subside` example above), not merely the same leaf count.\n", "value = {\n", "    'a': {'a': 1, 'b': {'x': 2, 'y': 3}},\n", "    'b': {'a': 10, 'b': {'x': 20, 'y': 30}},\n", "    'c': {\n", "        'x': {'a': 100, 'b': {'x': 200, 'y': 300}},\n", "        'y': {'a': 400, 'b': {'x': 500, 'y': 600}},\n", "    },\n", "}\n", "assert tree_structure(value) == sto.compose(sti)\n", "tree_transpose(sto, sti, value)" ] }, { "cell_type": "code", "execution_count": 40, "id": "d04072f3", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:28:45.519730Z", "iopub.status.busy": "2024-10-16T13:28:45.519332Z", "iopub.status.idle": "2024-10-16T13:28:53.562408Z", "shell.execute_reply": "2024-10-16T13:28:53.561671Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "9.89 µs ± 49.3 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit tree_transpose(sto, sti, value)" ] }, { "cell_type": "code", "execution_count": null, "id": "f2a580a3", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.18" } }, "nbformat": 4, "nbformat_minor": 5 }