{ "cells": [ { "cell_type": "markdown", "id": "1a1cab6c", "metadata": {}, "source": [ "# Comparison Between TreeValue and JAX Pytree" ] }, { "cell_type": "markdown", "id": "a9426f1e", "metadata": {}, "source": [ "In this section, we will take a look at the features and performance of the [JAX pytree](https://jax.readthedocs.io/en/latest/pytrees.html) utilities (`jax.tree_util`), which are developed by Google." ] }, { "cell_type": "code", "execution_count": 1, "id": "0f4b6d16", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:50.396419Z", "iopub.status.busy": "2024-10-16T13:06:50.396213Z", "iopub.status.idle": "2024-10-16T13:06:50.404103Z", "shell.execute_reply": "2024-10-16T13:06:50.403451Z" } }, "outputs": [], "source": [ "_TREE_DATA_1 = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}" ] }, { "cell_type": "markdown", "id": "b1085059", "metadata": {}, "source": [ "## Mapping Operation" ] }, { "cell_type": "markdown", "id": "4ded33f7", "metadata": {}, "source": [ "### TreeValue's Mapping" ] }, { "cell_type": "code", "execution_count": 2, "id": "aa9abbed", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:50.406530Z", "iopub.status.busy": "2024-10-16T13:06:50.406185Z", "iopub.status.idle": "2024-10-16T13:06:50.446848Z", "shell.execute_reply": "2024-10-16T13:06:50.446124Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> 1\n", "├── 'b' --> 4\n", "└── 'x' --> \n", " ├── 'c' --> 9\n", " └── 'd' --> 16" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import mapping, FastTreeValue\n", "\n", "t = FastTreeValue(_TREE_DATA_1)\n", "mapping(t, lambda x: x ** 2)" ] }, { "cell_type": "code", "execution_count": 3, "id": "4dc3420b", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:50.449019Z", "iopub.status.busy": "2024-10-16T13:06:50.448653Z", "iopub.status.idle": "2024-10-16T13:06:52.570353Z", "shell.execute_reply": "2024-10-16T13:06:52.569594Z" } }, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "2.61 µs ± 10.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit mapping(t, lambda x: x ** 2)" ] }, { "cell_type": "code", "execution_count": 4, "id": "289e14c6", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:52.572519Z", "iopub.status.busy": "2024-10-16T13:06:52.572122Z", "iopub.status.idle": "2024-10-16T13:06:52.576938Z", "shell.execute_reply": "2024-10-16T13:06:52.576287Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> (1, ('a',))\n", "├── 'b' --> (4, ('b',))\n", "└── 'x' --> \n", " ├── 'c' --> (9, ('x', 'c'))\n", " └── 'd' --> (16, ('x', 'd'))" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mapping(t, lambda x, p: (x ** 2, p))" ] }, { "cell_type": "code", "execution_count": 5, "id": "fce10189", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:52.579281Z", "iopub.status.busy": "2024-10-16T13:06:52.578896Z", "iopub.status.idle": "2024-10-16T13:06:54.808543Z", "shell.execute_reply": "2024-10-16T13:06:54.807820Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.73 µs ± 32.4 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit mapping(t, lambda x, p: (x ** 2, p))" ] }, { "cell_type": "markdown", "id": "18d4346d", "metadata": {}, "source": [ "### pytree's tree_map" ] }, { "cell_type": "code", "execution_count": 6, "id": "91b8e706", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:54.810970Z", "iopub.status.busy": "2024-10-16T13:06:54.810568Z", "iopub.status.idle": "2024-10-16T13:06:55.145464Z", "shell.execute_reply": "2024-10-16T13:06:55.144790Z" } }, "outputs": [ { "data": { "text/plain": [ "{'a': 1, 'b': 4, 'x': {'c': 9, 'd': 16}}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from jax.tree_util import tree_map\n", "\n", "tree_map(lambda x: x ** 2, _TREE_DATA_1)" ] }, { "cell_type": "code", "execution_count": 7, "id": "41fa094b", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:55.147732Z", "iopub.status.busy": "2024-10-16T13:06:55.147463Z", "iopub.status.idle": "2024-10-16T13:06:58.942972Z", "shell.execute_reply": "2024-10-16T13:06:58.942273Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4.67 µs ± 16.4 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit tree_map(lambda x: x ** 2, _TREE_DATA_1)" ] }, { "cell_type": "markdown", "id": "36191201", "metadata": {}, "source": [ "## Flatten and Unflatten Operation" ] }, { "cell_type": "markdown", "id": "a733f250", "metadata": {}, "source": [ "### TreeValue's Performance" ] }, { "cell_type": "code", "execution_count": 8, "id": "676f03c0", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:58.945533Z", "iopub.status.busy": "2024-10-16T13:06:58.945134Z", "iopub.status.idle": "2024-10-16T13:06:58.949845Z", "shell.execute_reply": "2024-10-16T13:06:58.949326Z" } }, "outputs": [ { "data": { "text/plain": [ "[(('a',), 1), (('b',), 2), (('x', 'c'), 3), (('x', 'd'), 4)]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import flatten, flatten_keys, flatten_values\n", "\n", "t_flatted = flatten(t)\n", "t_flatted" ] }, { "cell_type": "code", "execution_count": 9, "id": "7ca0363d", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:58.951807Z", "iopub.status.busy": "2024-10-16T13:06:58.951602Z", "iopub.status.idle": "2024-10-16T13:07:03.014257Z", "shell.execute_reply": "2024-10-16T13:07:03.013530Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "501 ns ± 2.84 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit flatten(t)" ] }, { "cell_type": "code", "execution_count": 10, "id": "62d3bd2d", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:03.016471Z", "iopub.status.busy": "2024-10-16T13:07:03.016093Z", "iopub.status.idle": "2024-10-16T13:07:03.020706Z", "shell.execute_reply": "2024-10-16T13:07:03.020055Z" } }, "outputs": [ { "data": { "text/plain": [ "[('a',), ('b',), ('x', 'c'), ('x', 'd')]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import flatten_keys\n", "\n", "flatten_keys(t)" ] }, { "cell_type": "code", "execution_count": 11, "id": "58fbf8c4", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:03.022729Z", "iopub.status.busy": "2024-10-16T13:07:03.022360Z", "iopub.status.idle": "2024-10-16T13:07:06.795147Z", "shell.execute_reply": "2024-10-16T13:07:06.794407Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "466 ns ± 7.09 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit flatten_keys(t)" ] }, { "cell_type": "code", "execution_count": 12, "id": "1448e341", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:06.797277Z", "iopub.status.busy": "2024-10-16T13:07:06.797043Z", "iopub.status.idle": "2024-10-16T13:07:06.801617Z", "shell.execute_reply": "2024-10-16T13:07:06.800960Z" } }, "outputs": [ { "data": { "text/plain": [ "[1, 2, 3, 4]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import flatten_values\n", "\n", "flatten_values(t)" ] }, { "cell_type": "code", "execution_count": 13, "id": "956cf1c1", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:06.803587Z", "iopub.status.busy": "2024-10-16T13:07:06.803240Z", "iopub.status.idle": "2024-10-16T13:07:09.442091Z", "shell.execute_reply": "2024-10-16T13:07:09.441306Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "325 ns ± 6.26 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit flatten_values(t)" ] }, { "cell_type": "code", "execution_count": 14, "id": "f35e3a6d", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:09.444194Z", "iopub.status.busy": "2024-10-16T13:07:09.443966Z", "iopub.status.idle": "2024-10-16T13:07:09.448646Z", "shell.execute_reply": "2024-10-16T13:07:09.447968Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> 1\n", "├── 'b' --> 2\n", "└── 'x' --> \n", " ├── 'c' --> 3\n", " └── 'd' --> 4" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import unflatten\n", "\n", "unflatten(t_flatted)" ] }, { "cell_type": "code", "execution_count": 15, "id": "52980ab8", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:09.450980Z", "iopub.status.busy": "2024-10-16T13:07:09.450530Z", "iopub.status.idle": "2024-10-16T13:07:14.288981Z", "shell.execute_reply": "2024-10-16T13:07:14.288350Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "594 ns ± 3.69 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit unflatten(t_flatted)" ] }, { "cell_type": "markdown", "id": "0ff80a6c", "metadata": {}, "source": [ "### pytree's Performance" ] }, { "cell_type": "code", "execution_count": 16, "id": "e89fc565", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:14.291425Z", "iopub.status.busy": "2024-10-16T13:07:14.290992Z", "iopub.status.idle": "2024-10-16T13:07:14.294778Z", "shell.execute_reply": "2024-10-16T13:07:14.294124Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Leaves: [1, 2, 3, 4]\n", "Treedef: PyTreeDef({'a': *, 'b': *, 'x': {'c': *, 'd': *}})\n" ] } ], "source": [ "from jax.tree_util import tree_flatten\n", "\n", "leaves, treedef = tree_flatten(_TREE_DATA_1)\n", "print('Leaves:', leaves)\n", "print('Treedef:', treedef)" ] }, { "cell_type": "code", "execution_count": 17, "id": "5d59c4d3", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:14.296883Z", "iopub.status.busy": "2024-10-16T13:07:14.296476Z", "iopub.status.idle": "2024-10-16T13:07:25.634508Z", "shell.execute_reply": "2024-10-16T13:07:25.633858Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.4 µs ± 5.15 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit tree_flatten(_TREE_DATA_1)" ] }, { "cell_type": "code", "execution_count": 18, "id": "bcb1318c", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:25.636832Z", "iopub.status.busy": "2024-10-16T13:07:25.636392Z", "iopub.status.idle": "2024-10-16T13:07:25.641207Z", "shell.execute_reply": "2024-10-16T13:07:25.640530Z" } }, "outputs": [ { "data": { "text/plain": [ "{'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from jax.tree_util import tree_unflatten\n", "\n", "tree_unflatten(treedef, leaves)" ] }, { "cell_type": "code", "execution_count": 19, "id": "638b5144", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:25.643580Z", "iopub.status.busy": "2024-10-16T13:07:25.643127Z", "iopub.status.idle": "2024-10-16T13:07:31.396648Z", "shell.execute_reply": "2024-10-16T13:07:31.395907Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "709 ns ± 2.47 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit tree_unflatten(treedef, leaves)" ] }, { "cell_type": "markdown", "id": "ffd1522b", "metadata": {}, "source": [ "## All Operation" ] }, { "cell_type": "markdown", "id": "112c91a1", "metadata": {}, "source": [ "### TreeValue's Performance" ] }, { "cell_type": "code", "execution_count": 20, "id": "1b9c1c51", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:31.399081Z", "iopub.status.busy": "2024-10-16T13:07:31.398853Z", "iopub.status.idle": "2024-10-16T13:07:31.403545Z", "shell.execute_reply": "2024-10-16T13:07:31.402873Z" } }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "all(flatten_values(t))" ] }, { "cell_type": "code", "execution_count": 21, "id": "fb8b0b2d", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:31.405588Z", "iopub.status.busy": "2024-10-16T13:07:31.405235Z", "iopub.status.idle": "2024-10-16T13:07:34.901250Z", "shell.execute_reply": "2024-10-16T13:07:34.900414Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "432 ns ± 1.22 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit all(flatten_values(t))" ] }, { "cell_type": "markdown", "id": "86462abd", "metadata": {}, "source": [ "### pytree.tree_all's performance" ] }, { "cell_type": "code", "execution_count": 22, "id": "64d39484", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:34.903499Z", "iopub.status.busy": "2024-10-16T13:07:34.903272Z", "iopub.status.idle": "2024-10-16T13:07:34.906593Z", "shell.execute_reply": "2024-10-16T13:07:34.905911Z" } }, "outputs": [], "source": [ "from jax.tree_util import tree_all" ] }, { "cell_type": "code", "execution_count": 23, "id": "6a8427db", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:34.908639Z", "iopub.status.busy": "2024-10-16T13:07:34.908183Z", "iopub.status.idle": "2024-10-16T13:07:34.912518Z", "shell.execute_reply": "2024-10-16T13:07:34.911890Z" } }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree_all(_TREE_DATA_1)" ] }, { "cell_type": "code", "execution_count": 24, "id": "e4a4cd59", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:34.914572Z", "iopub.status.busy": "2024-10-16T13:07:34.914202Z", "iopub.status.idle": "2024-10-16T13:07:47.713550Z", "shell.execute_reply": "2024-10-16T13:07:47.712931Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.58 µs ± 15.2 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit tree_all(_TREE_DATA_1)" ] }, { "cell_type": "markdown", "id": "f0221182", "metadata": {}, "source": [ "## Reduce Operation" ] }, { "cell_type": "markdown", "id": "646b4a2e", "metadata": {}, "source": [ "### TreeValue's Reduce" ] }, { "cell_type": "code", "execution_count": 25, "id": "105e9001", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:47.715963Z", "iopub.status.busy": "2024-10-16T13:07:47.715553Z", "iopub.status.idle": "2024-10-16T13:07:47.720503Z", "shell.execute_reply": "2024-10-16T13:07:47.719981Z" } }, "outputs": [ { "data": { "text/plain": [ "10" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from functools import reduce\n", "\n", "def _flatten_reduce(tree):\n", " values = flatten_values(tree)\n", " return reduce(lambda x, y: x + y, values)\n", "\n", "_flatten_reduce(t)" ] }, { "cell_type": "code", "execution_count": 26, "id": "71c145c0", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:47.722318Z", "iopub.status.busy": "2024-10-16T13:07:47.722131Z", "iopub.status.idle": "2024-10-16T13:07:54.338354Z", "shell.execute_reply": "2024-10-16T13:07:54.337655Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "815 ns ± 3.75 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit _flatten_reduce(t)" ] }, { "cell_type": "code", "execution_count": 27, "id": "20b34964", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:54.340430Z", "iopub.status.busy": "2024-10-16T13:07:54.340230Z", "iopub.status.idle": "2024-10-16T13:07:54.345101Z", "shell.execute_reply": "2024-10-16T13:07:54.344557Z" } }, "outputs": [ { "data": { "text/plain": [ "10" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def _flatten_reduce_with_init(tree):\n", " values = flatten_values(tree)\n", " return reduce(lambda x, y: x + y, values, 0)\n", "\n", "_flatten_reduce_with_init(t)" ] }, { "cell_type": "code", "execution_count": 28, "id": "ca8562c5", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:07:54.347189Z", "iopub.status.busy": "2024-10-16T13:07:54.346798Z", "iopub.status.idle": "2024-10-16T13:08:01.265056Z", "shell.execute_reply": "2024-10-16T13:08:01.264440Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "851 ns ± 5.97 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit _flatten_reduce_with_init(t)" ] }, { "cell_type": "markdown", "id": "12b2a28a", "metadata": {}, "source": [ "### pytree.tree_reduce" ] }, { "cell_type": "code", "execution_count": 29, "id": "14e9237e", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:08:01.267427Z", "iopub.status.busy": "2024-10-16T13:08:01.266927Z", "iopub.status.idle": "2024-10-16T13:08:01.271682Z", "shell.execute_reply": "2024-10-16T13:08:01.271140Z" } }, "outputs": [ { "data": { "text/plain": [ "10" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from jax.tree_util import tree_reduce\n", "\n", "tree_reduce(lambda x, y: x + y, _TREE_DATA_1)" ] }, { "cell_type": "code", "execution_count": 30, "id": "f6c5cfda", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:08:01.273498Z", "iopub.status.busy": "2024-10-16T13:08:01.273295Z", "iopub.status.idle": "2024-10-16T13:08:02.938972Z", "shell.execute_reply": "2024-10-16T13:08:02.938244Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.05 µs ± 17.8 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit tree_reduce(lambda x, y: x + y, _TREE_DATA_1)" ] }, { "cell_type": "code", "execution_count": 31, "id": "e79fb73b", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:08:02.941139Z", "iopub.status.busy": "2024-10-16T13:08:02.940929Z", "iopub.status.idle": "2024-10-16T13:08:02.945349Z", "shell.execute_reply": "2024-10-16T13:08:02.944825Z" } }, "outputs": [ { "data": { "text/plain": [ "10" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree_reduce(lambda x, y: x + y, _TREE_DATA_1, 0)" ] }, { "cell_type": "code", "execution_count": 32, "id": "af47dd74", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:08:02.947448Z", "iopub.status.busy": "2024-10-16T13:08:02.947062Z", "iopub.status.idle": "2024-10-16T13:08:04.660393Z", "shell.execute_reply": "2024-10-16T13:08:04.659616Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.11 µs ± 14.7 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit tree_reduce(lambda x, y: x + y, _TREE_DATA_1, 0)" ] }, { "cell_type": "markdown", "id": "6e58b70f", "metadata": {}, "source": [ "## Structure Transpose" ] }, { "cell_type": "markdown", "id": "a4a31f9e", "metadata": {}, "source": [ "### Subside and Rise in TreeValue" ] }, { "cell_type": "code", "execution_count": 33, "id": "b6af689f", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:08:04.662969Z", "iopub.status.busy": "2024-10-16T13:08:04.662516Z", "iopub.status.idle": "2024-10-16T13:08:04.669221Z", "shell.execute_reply": "2024-10-16T13:08:04.668539Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> {'a': 1, 'b': 10, 'c': {'x': 100, 'y': 400}}\n", "└── 'b' --> \n", " ├── 'x' --> {'a': 2, 'b': 20, 'c': {'x': 200, 'y': 500}}\n", " └── 'y' --> {'a': 3, 'b': 30, 'c': {'x': 300, 'y': 600}}" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import subside\n", "\n", "value = {\n", " 'a': FastTreeValue({'a': 1, 'b': {'x': 2, 'y': 3}}),\n", " 'b': FastTreeValue({'a': 10, 'b': {'x': 20, 'y': 30}}),\n", " 'c': {\n", " 'x': FastTreeValue({'a': 100, 'b': {'x': 200, 'y': 300}}),\n", " 'y': FastTreeValue({'a': 400, 'b': {'x': 500, 'y': 600}}),\n", " },\n", "}\n", "subside(value)" ] }, { "cell_type": "code", "execution_count": 34, "id": "b85caa84", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:08:04.671452Z", "iopub.status.busy": "2024-10-16T13:08:04.671041Z", "iopub.status.idle": "2024-10-16T13:08:13.212248Z", "shell.execute_reply": "2024-10-16T13:08:13.211555Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10.5 µs ± 12.6 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit subside(value)" ] }, { "cell_type": "code", "execution_count": 35, "id": "f304d286", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:08:13.214777Z", "iopub.status.busy": "2024-10-16T13:08:13.214320Z", "iopub.status.idle": "2024-10-16T13:08:13.220425Z", "shell.execute_reply": "2024-10-16T13:08:13.219882Z" } }, "outputs": [ { "data": { "text/plain": [ "{'a': \n", " ├── 'a' --> 1\n", " ├── 'b' --> 10\n", " └── 'c' --> \n", " ├── 'x' --> 100\n", " └── 'y' --> 400,\n", " 'b': {'x': \n", " ├── 'a' --> 2\n", " ├── 'b' --> 20\n", " └── 'c' --> \n", " ├── 'x' --> 200\n", " └── 'y' --> 500,\n", " 'y': \n", " ├── 'a' --> 3\n", " ├── 'b' --> 30\n", " └── 'c' --> \n", " ├── 'x' --> 300\n", " └── 'y' --> 600}}" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import raw, rise\n", "\n", "value = FastTreeValue({\n", " 'a': raw({'a': 1, 'b': {'x': 2, 'y': 3}}),\n", " 'b': raw({'a': 10, 'b': {'x': 20, 'y': 30}}),\n", " 'c': {\n", " 'x': raw({'a': 100, 'b': {'x': 200, 'y': 300}}),\n", " 'y': raw({'a': 400, 'b': {'x': 500, 'y': 600}}),\n", " },\n", "})\n", "rise(value)" ] }, { "cell_type": "code", "execution_count": 36, "id": "62d34321", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:08:13.222544Z", "iopub.status.busy": "2024-10-16T13:08:13.222179Z", "iopub.status.idle": "2024-10-16T13:08:22.458820Z", "shell.execute_reply": "2024-10-16T13:08:22.458115Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "11.4 µs ± 139 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit rise(value)" ] }, { "cell_type": "code", "execution_count": 37, "id": "f0ed4793", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:08:22.461057Z", "iopub.status.busy": "2024-10-16T13:08:22.460686Z", "iopub.status.idle": "2024-10-16T13:08:22.465638Z", "shell.execute_reply": "2024-10-16T13:08:22.464988Z" } }, "outputs": [ { "data": { "text/plain": [ "{'a': \n", " ├── 'a' --> 1\n", " ├── 'b' --> 10\n", " └── 'c' --> \n", " ├── 'x' --> 100\n", " └── 'y' --> 400,\n", " 'b': {'x': \n", " ├── 'a' --> 2\n", " ├── 'b' --> 20\n", " └── 'c' --> \n", " ├── 'x' --> 200\n", " └── 'y' --> 500,\n", " 'y': \n", " ├── 'a' --> 3\n", " ├── 'b' --> 30\n", " └── 'c' --> \n", " ├── 'x' --> 300\n", " └── 'y' --> 600}}" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vt = {'a': None, 'b': {'x': None, 'y': None}}\n", "rise(value, template=vt)" ] }, { "cell_type": "code", "execution_count": 38, "id": "a6ad3639", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:08:22.467934Z", "iopub.status.busy": "2024-10-16T13:08:22.467571Z", "iopub.status.idle": "2024-10-16T13:08:29.786706Z", "shell.execute_reply": "2024-10-16T13:08:29.786057Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "9.01 µs ± 38.4 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit rise(value, template=vt)" ] }, { "cell_type": "markdown", "id": "19ee26c6", "metadata": {}, "source": [ "### pytree.tree_transpose" ] }, { "cell_type": "code", "execution_count": 39, "id": "7f24f0f6", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:08:29.789018Z", "iopub.status.busy": "2024-10-16T13:08:29.788618Z", "iopub.status.idle": "2024-10-16T13:08:29.795688Z", "shell.execute_reply": "2024-10-16T13:08:29.795151Z" } }, "outputs": [ { "data": { "text/plain": [ "{'a': {'a': 1, 'b': 10, 'c': {'x': 100, 'y': 400}},\n", " 'b': {'x': {'a': 2, 'b': 20, 'c': {'x': 200, 'y': 500}},\n", " 'y': {'a': 3, 'b': 30, 'c': {'x': 300, 'y': 600}}}}" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from jax.tree_util import tree_structure, tree_transpose\n", "\n", "sto = tree_structure({'a': 1, 'b': 2, 'c': {'x': 3, 'y': 4}})\n", "sti = tree_structure({'a': 1, 'b': {'x': 2, 'y': 3}})\n", "\n", "value = (\n", " {'a': 1, 'b': {'x': 2, 'y': 3}},\n", " {\n", " 'a': {'a': 10, 'b': {'x': 20, 'y': 30}},\n", " 'b': [\n", " {'a': 100, 'b': {'x': 200, 'y': 300}},\n", " {'a': 400, 'b': {'x': 500, 'y': 600}},\n", " ],\n", " }\n", ")\n", "tree_transpose(sto, sti, value)" ] }, { "cell_type": "code", "execution_count": 40, "id": "d04072f3", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:08:29.797762Z", "iopub.status.busy": "2024-10-16T13:08:29.797378Z", "iopub.status.idle": "2024-10-16T13:08:38.065226Z", "shell.execute_reply": "2024-10-16T13:08:38.064624Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10.2 µs ± 46.6 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit tree_transpose(sto, sti, value)" ] }, { "cell_type": "code", "execution_count": null, "id": "f2a580a3", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.18" } }, "nbformat": 4, "nbformat_minor": 5 }