{ "cells": [ { "cell_type": "markdown", "id": "1a1cab6c", "metadata": {}, "source": [ "# Comparison Between TreeValue and Jax LibTree" ] }, { "cell_type": "markdown", "id": "a9426f1e", "metadata": {}, "source": [ "In this section, we will take a look at the features and performance of JAX's [pytree](https://jax.readthedocs.io/en/latest/pytrees.html) utilities (`jax.tree_util`), developed by Google, and compare them with TreeValue." ] }, { "cell_type": "code", "execution_count": 1, "id": "0f4b6d16", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:50.983586Z", "iopub.status.busy": "2024-10-16T13:18:50.983391Z", "iopub.status.idle": "2024-10-16T13:18:50.991139Z", "shell.execute_reply": "2024-10-16T13:18:50.990611Z" } }, "outputs": [], "source": [ "_TREE_DATA_1 = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}" ] }, { "cell_type": "markdown", "id": "b1085059", "metadata": {}, "source": [ "## Mapping Operation" ] }, { "cell_type": "markdown", "id": "4ded33f7", "metadata": {}, "source": [ "### TreeValue's Mapping" ] }, { "cell_type": "code", "execution_count": 2, "id": "aa9abbed", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:50.993266Z", "iopub.status.busy": "2024-10-16T13:18:50.993069Z", "iopub.status.idle": "2024-10-16T13:18:51.360275Z", "shell.execute_reply": "2024-10-16T13:18:51.359578Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> 1\n", "├── 'b' --> 4\n", "└── 'x' --> \n", " ├── 'c' --> 9\n", " └── 'd' --> 16" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import mapping, FastTreeValue\n", "\n", "t = FastTreeValue(_TREE_DATA_1)\n", "mapping(t, lambda x: x ** 2)" ] }, { "cell_type": "code", "execution_count": 3, "id": "4dc3420b", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:51.362602Z", "iopub.status.busy": "2024-10-16T13:18:51.362214Z", "iopub.status.idle": "2024-10-16T13:18:53.543142Z", "shell.execute_reply": "2024-10-16T13:18:53.542378Z" } }, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "2.69 µs ± 8.48 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit mapping(t, lambda x: x ** 2)" ] }, { "cell_type": "code", "execution_count": 4, "id": "289e14c6", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:53.545371Z", "iopub.status.busy": "2024-10-16T13:18:53.544958Z", "iopub.status.idle": "2024-10-16T13:18:53.549663Z", "shell.execute_reply": "2024-10-16T13:18:53.549128Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> (1, ('a',))\n", "├── 'b' --> (4, ('b',))\n", "└── 'x' --> \n", " ├── 'c' --> (9, ('x', 'c'))\n", " └── 'd' --> (16, ('x', 'd'))" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mapping(t, lambda x, p: (x ** 2, p))" ] }, { "cell_type": "code", "execution_count": 5, "id": "fce10189", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:53.551886Z", "iopub.status.busy": "2024-10-16T13:18:53.551498Z", "iopub.status.idle": "2024-10-16T13:18:55.757519Z", "shell.execute_reply": "2024-10-16T13:18:55.756819Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.71 µs ± 19.6 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit mapping(t, lambda x, p: (x ** 2, p))" ] }, { "cell_type": "markdown", "id": "18d4346d", "metadata": {}, "source": [ "### pytree's tree_map" ] }, { "cell_type": "code", "execution_count": 6, "id": "91b8e706", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:55.760160Z", "iopub.status.busy": "2024-10-16T13:18:55.759720Z", "iopub.status.idle": "2024-10-16T13:18:55.764490Z", "shell.execute_reply": "2024-10-16T13:18:55.763948Z" } }, "outputs": [ { "data": { "text/plain": [ "{'a': 1, 'b': 4, 'x': {'c': 9, 'd': 16}}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from jax.tree_util import tree_map\n", "\n", "tree_map(lambda x: x ** 2, _TREE_DATA_1)" ] }, { "cell_type": "code", "execution_count": 7, "id": "41fa094b", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:55.766603Z", "iopub.status.busy": "2024-10-16T13:18:55.766283Z", "iopub.status.idle": "2024-10-16T13:18:59.471999Z", "shell.execute_reply": "2024-10-16T13:18:59.471248Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4.57 µs ± 28.5 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit tree_map(lambda x: x ** 2, _TREE_DATA_1)" ] }, { "cell_type": "markdown", "id": "36191201", "metadata": {}, "source": [ "## Flatten and Unflatten Operation" ] }, { "cell_type": "markdown", "id": "a733f250", "metadata": {}, "source": [ "### TreeValue's Performance" ] }, { "cell_type": "code", "execution_count": 8, "id": "676f03c0", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:59.474296Z", "iopub.status.busy": "2024-10-16T13:18:59.474038Z", "iopub.status.idle": "2024-10-16T13:18:59.478783Z", "shell.execute_reply": "2024-10-16T13:18:59.478258Z" } }, "outputs": [ { "data": { "text/plain": [ "[(('a',), 1), (('b',), 2), (('x', 'c'), 3), (('x', 'd'), 4)]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import flatten, flatten_keys, flatten_values\n", "\n", "t_flatted = flatten(t)\n", "t_flatted" ] }, { "cell_type": "code", "execution_count": 9, "id": "7ca0363d", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:59.480662Z", "iopub.status.busy": "2024-10-16T13:18:59.480454Z", "iopub.status.idle": "2024-10-16T13:19:03.583465Z", "shell.execute_reply": "2024-10-16T13:19:03.582798Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "505 ns ± 3.99 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit flatten(t)" ] }, { "cell_type": "code", "execution_count": 10, "id": "62d3bd2d", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:03.585544Z", "iopub.status.busy": "2024-10-16T13:19:03.585308Z", "iopub.status.idle": "2024-10-16T13:19:03.589907Z", "shell.execute_reply": "2024-10-16T13:19:03.589246Z" } }, "outputs": [ { "data": { "text/plain": [ "[('a',), ('b',), ('x', 'c'), ('x', 'd')]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import flatten_keys\n", "\n", "flatten_keys(t)" ] }, { "cell_type": "code", "execution_count": 11, "id": "58fbf8c4", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:03.591942Z", "iopub.status.busy": "2024-10-16T13:19:03.591579Z", "iopub.status.idle": "2024-10-16T13:19:07.333660Z", "shell.execute_reply": "2024-10-16T13:19:07.332948Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "460 ns ± 1.68 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit flatten_keys(t)" ] }, { "cell_type": "code", "execution_count": 12, "id": "1448e341", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:07.335825Z", "iopub.status.busy": "2024-10-16T13:19:07.335613Z", "iopub.status.idle": "2024-10-16T13:19:07.339914Z", "shell.execute_reply": "2024-10-16T13:19:07.339395Z" } }, "outputs": [ { "data": { "text/plain": [ "[1, 2, 3, 4]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import flatten_values\n", "\n", "flatten_values(t)" ] }, { "cell_type": "code", "execution_count": 13, "id": "956cf1c1", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:07.341799Z", "iopub.status.busy": "2024-10-16T13:19:07.341597Z", "iopub.status.idle": "2024-10-16T13:19:10.080095Z", "shell.execute_reply": "2024-10-16T13:19:10.079446Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "337 ns ± 3.06 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit flatten_values(t)" ] }, { "cell_type": "code", "execution_count": 14, "id": "f35e3a6d", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:10.082222Z", "iopub.status.busy": "2024-10-16T13:19:10.081958Z", "iopub.status.idle": "2024-10-16T13:19:10.086770Z", "shell.execute_reply": "2024-10-16T13:19:10.086207Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> 1\n", "├── 'b' --> 2\n", "└── 'x' --> \n", " ├── 'c' --> 3\n", " └── 'd' --> 4" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import unflatten\n", "\n", "unflatten(t_flatted)" ] }, { "cell_type": "code", "execution_count": 15, "id": "52980ab8", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:10.088857Z", "iopub.status.busy": "2024-10-16T13:19:10.088488Z", "iopub.status.idle": "2024-10-16T13:19:14.858362Z", "shell.execute_reply": "2024-10-16T13:19:14.857667Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "587 ns ± 1.75 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit unflatten(t_flatted)" ] }, { "cell_type": "markdown", "id": "0ff80a6c", "metadata": {}, "source": [ "### pytree's Performance" ] }, { "cell_type": "code", "execution_count": 16, "id": "e89fc565", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:14.860533Z", "iopub.status.busy": "2024-10-16T13:19:14.860326Z", "iopub.status.idle": "2024-10-16T13:19:14.864271Z", "shell.execute_reply": "2024-10-16T13:19:14.863617Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Leaves: [1, 2, 3, 4]\n", "Treedef: PyTreeDef({'a': *, 'b': *, 'x': {'c': *, 'd': *}})\n" ] } ], "source": [ "from jax.tree_util import tree_flatten\n", "\n", "leaves, treedef = tree_flatten(_TREE_DATA_1)\n", "print('Leaves:', leaves)\n", "print('Treedef:', treedef)" ] }, { "cell_type": "code", "execution_count": 17, "id": "5d59c4d3", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:14.866225Z", "iopub.status.busy": "2024-10-16T13:19:14.866018Z", "iopub.status.idle": "2024-10-16T13:19:26.373843Z", "shell.execute_reply": "2024-10-16T13:19:26.373141Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.42 µs ± 11.2 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit tree_flatten(_TREE_DATA_1)" ] }, { "cell_type": "code", "execution_count": 18, "id": "bcb1318c", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:26.376229Z", "iopub.status.busy": "2024-10-16T13:19:26.375808Z", "iopub.status.idle": "2024-10-16T13:19:26.380234Z", "shell.execute_reply": "2024-10-16T13:19:26.379720Z" } }, "outputs": [ { "data": { "text/plain": [ "{'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from jax.tree_util import tree_unflatten\n", "\n", "tree_unflatten(treedef, leaves)" ] }, { "cell_type": "code", "execution_count": 19, "id": "638b5144", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:26.382177Z", "iopub.status.busy": "2024-10-16T13:19:26.381821Z", "iopub.status.idle": "2024-10-16T13:19:31.930741Z", "shell.execute_reply": "2024-10-16T13:19:31.930085Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "683 ns ± 2.4 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit tree_unflatten(treedef, leaves)" ] }, { "cell_type": "markdown", "id": "ffd1522b", "metadata": {}, "source": [ "## All Operation" ] }, { "cell_type": "markdown", "id": "112c91a1", "metadata": {}, "source": [ "### TreeValue's Performance" ] }, { "cell_type": "code", "execution_count": 20, "id": "1b9c1c51", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:31.933243Z", "iopub.status.busy": "2024-10-16T13:19:31.932794Z", "iopub.status.idle": "2024-10-16T13:19:31.937346Z", "shell.execute_reply": "2024-10-16T13:19:31.936793Z" } }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "all(flatten_values(t))" ] }, { "cell_type": "code", "execution_count": 21, "id": "fb8b0b2d", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:31.939461Z", "iopub.status.busy": "2024-10-16T13:19:31.939062Z", "iopub.status.idle": "2024-10-16T13:19:35.527018Z", "shell.execute_reply": "2024-10-16T13:19:35.526338Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "441 ns ± 2.85 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit all(flatten_values(t))" ] }, { "cell_type": "markdown", "id": "86462abd", "metadata": {}, "source": [ "### pytree.tree_all's performance" ] }, { "cell_type": "code", "execution_count": 22, "id": "64d39484", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:35.529312Z", "iopub.status.busy": "2024-10-16T13:19:35.528933Z", "iopub.status.idle": "2024-10-16T13:19:35.532116Z", "shell.execute_reply": "2024-10-16T13:19:35.531475Z" } }, "outputs": [], "source": [ "from jax.tree_util import tree_all" ] }, { "cell_type": "code", "execution_count": 23, "id": "6a8427db", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:35.534376Z", "iopub.status.busy": "2024-10-16T13:19:35.534004Z", "iopub.status.idle": "2024-10-16T13:19:35.538271Z", "shell.execute_reply": "2024-10-16T13:19:35.537610Z" } }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree_all(_TREE_DATA_1)" ] }, { "cell_type": "code", "execution_count": 24, "id": "e4a4cd59", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:35.540303Z", "iopub.status.busy": "2024-10-16T13:19:35.539951Z", "iopub.status.idle": "2024-10-16T13:19:48.507770Z", "shell.execute_reply": "2024-10-16T13:19:48.507092Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.6 µs ± 6.3 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit tree_all(_TREE_DATA_1)" ] }, { "cell_type": "markdown", "id": "f0221182", "metadata": {}, "source": [ "## Reduce Operation" ] }, { "cell_type": "markdown", "id": "646b4a2e", "metadata": {}, "source": [ "### TreeValue's Reduce" ] }, { "cell_type": "code", "execution_count": 25, "id": "105e9001", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:48.510130Z", "iopub.status.busy": "2024-10-16T13:19:48.509705Z", "iopub.status.idle": "2024-10-16T13:19:48.514695Z", "shell.execute_reply": "2024-10-16T13:19:48.514020Z" } }, "outputs": [ { "data": { "text/plain": [ "10" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from functools import reduce\n", "\n", "def _flatten_reduce(tree):\n", " values = flatten_values(tree)\n", " return reduce(lambda x, y: x + y, values)\n", "\n", "_flatten_reduce(t)" ] }, { "cell_type": "code", "execution_count": 26, "id": "71c145c0", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:48.516744Z", "iopub.status.busy": "2024-10-16T13:19:48.516382Z", "iopub.status.idle": "2024-10-16T13:19:55.052187Z", "shell.execute_reply": "2024-10-16T13:19:55.051435Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "804 ns ± 6.39 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit _flatten_reduce(t)" ] }, { "cell_type": "code", "execution_count": 27, "id": "20b34964", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:55.054695Z", "iopub.status.busy": "2024-10-16T13:19:55.054146Z", "iopub.status.idle": "2024-10-16T13:19:55.059174Z", "shell.execute_reply": "2024-10-16T13:19:55.058536Z" } }, "outputs": [ { "data": { "text/plain": [ "10" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def _flatten_reduce_with_init(tree):\n", " values = flatten_values(tree)\n", " return reduce(lambda x, y: x + y, values, 0)\n", "\n", "_flatten_reduce_with_init(t)" ] }, { "cell_type": "code", "execution_count": 28, "id": "ca8562c5", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:19:55.061353Z", "iopub.status.busy": "2024-10-16T13:19:55.060991Z", "iopub.status.idle": "2024-10-16T13:20:02.108986Z", "shell.execute_reply": "2024-10-16T13:20:02.108241Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "868 ns ± 6.33 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit _flatten_reduce_with_init(t)" ] }, { "cell_type": "markdown", "id": "12b2a28a", "metadata": {}, "source": [ "### pytree.tree_reduce" ] }, { "cell_type": "code", "execution_count": 29, "id": "14e9237e", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:20:02.111401Z", "iopub.status.busy": "2024-10-16T13:20:02.110987Z", "iopub.status.idle": "2024-10-16T13:20:02.115706Z", "shell.execute_reply": "2024-10-16T13:20:02.115063Z" } }, "outputs": [ { "data": { "text/plain": [ "10" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from jax.tree_util import tree_reduce\n", "\n", "tree_reduce(lambda x, y: x + y, _TREE_DATA_1)" ] }, { "cell_type": "code", "execution_count": 30, "id": "f6c5cfda", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:20:02.117686Z", "iopub.status.busy": "2024-10-16T13:20:02.117318Z", "iopub.status.idle": "2024-10-16T13:20:03.779615Z", "shell.execute_reply": "2024-10-16T13:20:03.778904Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.04 µs ± 50.4 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit tree_reduce(lambda x, y: x + y, _TREE_DATA_1)" ] }, { "cell_type": "code", "execution_count": 31, "id": "e79fb73b", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:20:03.781917Z", "iopub.status.busy": "2024-10-16T13:20:03.781447Z", "iopub.status.idle": "2024-10-16T13:20:03.786269Z", "shell.execute_reply": "2024-10-16T13:20:03.785617Z" } }, "outputs": [ { "data": { "text/plain": [ "10" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree_reduce(lambda x, y: x + y, _TREE_DATA_1, 0)" ] }, { "cell_type": "code", "execution_count": 32, "id": "af47dd74", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:20:03.788398Z", "iopub.status.busy": "2024-10-16T13:20:03.788000Z", "iopub.status.idle": "2024-10-16T13:20:05.542797Z", "shell.execute_reply": "2024-10-16T13:20:05.542048Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.15 µs ± 15.4 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit tree_reduce(lambda x, y: x + y, _TREE_DATA_1, 0)" ] }, { "cell_type": "markdown", "id": "6e58b70f", "metadata": {}, "source": [ "## Structure Transpose" ] }, { "cell_type": "markdown", "id": "a4a31f9e", "metadata": {}, "source": [ "### Subside and Rise in TreeValue" ] }, { "cell_type": "code", "execution_count": 33, "id": "b6af689f", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:20:05.545115Z", "iopub.status.busy": "2024-10-16T13:20:05.544778Z", "iopub.status.idle": "2024-10-16T13:20:05.551239Z", "shell.execute_reply": "2024-10-16T13:20:05.550682Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> {'a': 1, 'b': 10, 'c': {'x': 100, 'y': 400}}\n", "└── 'b' --> \n", " ├── 'x' --> {'a': 2, 'b': 20, 'c': {'x': 200, 'y': 500}}\n", " └── 'y' --> {'a': 3, 'b': 30, 'c': {'x': 300, 'y': 600}}" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import subside\n", "\n", "value = {\n", " 'a': FastTreeValue({'a': 1, 'b': {'x': 2, 'y': 3}}),\n", " 'b': FastTreeValue({'a': 10, 'b': {'x': 20, 'y': 30}}),\n", " 'c': {\n", " 'x': FastTreeValue({'a': 100, 'b': {'x': 200, 'y': 300}}),\n", " 'y': FastTreeValue({'a': 400, 'b': {'x': 500, 'y': 600}}),\n", " },\n", "}\n", "subside(value)" ] }, { "cell_type": "code", "execution_count": 34, "id": "b85caa84", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:20:05.553235Z", "iopub.status.busy": "2024-10-16T13:20:05.552873Z", "iopub.status.idle": "2024-10-16T13:20:14.083642Z", "shell.execute_reply": "2024-10-16T13:20:14.082996Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10.5 µs ± 50.8 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit subside(value)" ] }, { "cell_type": "code", "execution_count": 35, "id": "f304d286", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:20:14.085891Z", "iopub.status.busy": "2024-10-16T13:20:14.085500Z", "iopub.status.idle": "2024-10-16T13:20:14.091637Z", "shell.execute_reply": "2024-10-16T13:20:14.091095Z" } }, "outputs": [ { "data": { "text/plain": [ "{'b': {'x': \n", " ├── 'a' --> 2\n", " ├── 'b' --> 20\n", " └── 'c' --> \n", " ├── 'x' --> 200\n", " └── 'y' --> 500,\n", " 'y': \n", " ├── 'a' --> 3\n", " ├── 'b' --> 30\n", " └── 'c' --> \n", " ├── 'x' --> 300\n", " └── 'y' --> 600},\n", " 'a': \n", " ├── 'a' --> 1\n", " ├── 'b' --> 10\n", " └── 'c' --> \n", " ├── 'x' --> 100\n", " └── 'y' --> 400}" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from treevalue import raw, rise\n", "\n", "value = FastTreeValue({\n", " 'a': raw({'a': 1, 'b': {'x': 2, 'y': 3}}),\n", " 'b': raw({'a': 10, 'b': {'x': 20, 'y': 30}}),\n", " 'c': {\n", " 'x': raw({'a': 100, 'b': {'x': 200, 'y': 300}}),\n", " 'y': raw({'a': 400, 'b': {'x': 500, 'y': 600}}),\n", " },\n", "})\n", "rise(value)" ] }, { "cell_type": "code", "execution_count": 36, "id": "62d34321", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:20:14.093716Z", "iopub.status.busy": "2024-10-16T13:20:14.093323Z", "iopub.status.idle": "2024-10-16T13:20:23.516485Z", "shell.execute_reply": "2024-10-16T13:20:23.515800Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "11.6 µs ± 54.5 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit rise(value)" ] }, { "cell_type": "code", "execution_count": 37, "id": "f0ed4793", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:20:23.518831Z", "iopub.status.busy": "2024-10-16T13:20:23.518421Z", "iopub.status.idle": "2024-10-16T13:20:23.523527Z", "shell.execute_reply": "2024-10-16T13:20:23.522999Z" } }, "outputs": [ { "data": { "text/plain": [ "{'a': \n", " ├── 'a' --> 1\n", " ├── 'b' --> 10\n", " └── 'c' --> \n", " ├── 'x' --> 100\n", " └── 'y' --> 400,\n", " 'b': {'x': \n", " ├── 'a' --> 2\n", " ├── 'b' --> 20\n", " └── 'c' --> \n", " ├── 'x' --> 200\n", " └── 'y' --> 500,\n", " 'y': \n", " ├── 'a' --> 3\n", " ├── 'b' --> 30\n", " └── 'c' --> \n", " ├── 'x' --> 300\n", " └── 'y' --> 600}}" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vt = {'a': None, 'b': {'x': None, 'y': None}}\n", "rise(value, template=vt)" ] }, { "cell_type": "code", "execution_count": 38, "id": "a6ad3639", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:20:23.525685Z", "iopub.status.busy": "2024-10-16T13:20:23.525304Z", "iopub.status.idle": "2024-10-16T13:20:31.022052Z", "shell.execute_reply": "2024-10-16T13:20:31.021361Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "9.23 µs ± 83 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit rise(value, template=vt)" ] }, { "cell_type": "markdown", "id": "19ee26c6", "metadata": {}, "source": [ "### pytree.tree_transpose" ] }, { "cell_type": "code", "execution_count": 39, "id": "7f24f0f6", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:20:31.024450Z", "iopub.status.busy": "2024-10-16T13:20:31.023994Z", "iopub.status.idle": "2024-10-16T13:20:31.030872Z", "shell.execute_reply": "2024-10-16T13:20:31.030356Z" } }, "outputs": [ { "data": { "text/plain": [ "{'a': {'a': 1, 'b': 10, 'c': {'x': 100, 'y': 400}},\n", " 'b': {'x': {'a': 2, 'b': 20, 'c': {'x': 200, 'y': 500}},\n", " 'y': {'a': 3, 'b': 30, 'c': {'x': 300, 'y': 600}}}}" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from jax.tree_util import tree_structure, tree_transpose\n", "\n", "sto = tree_structure({'a': 1, 'b': 2, 'c': {'x': 3, 'y': 4}})\n", "sti = tree_structure({'a': 1, 'b': {'x': 2, 'y': 3}})\n", "\n", "# tree_transpose expects an input whose structure is `sto` composed with `sti`,\n", "# i.e. the same nesting as the `subside` example above (plain dicts instead of trees)\n", "value = {\n", "    'a': {'a': 1, 'b': {'x': 2, 'y': 3}},\n", "    'b': {'a': 10, 'b': {'x': 20, 'y': 30}},\n", "    'c': {\n", "        'x': {'a': 100, 'b': {'x': 200, 'y': 300}},\n", "        'y': {'a': 400, 'b': {'x': 500, 'y': 600}},\n", "    },\n", "}\n", "tree_transpose(sto, sti, value)" ] }, { "cell_type": "code", "execution_count": 40, "id": "d04072f3", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:20:31.032887Z", "iopub.status.busy": "2024-10-16T13:20:31.032473Z", "iopub.status.idle": "2024-10-16T13:20:39.153063Z", "shell.execute_reply": "2024-10-16T13:20:39.152351Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10 µs ± 22 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit tree_transpose(sto, sti, value)" ] } ],
 "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.18" } }, "nbformat": 4, "nbformat_minor": 5 }