{ "cells": [ { "cell_type": "markdown", "id": "9d39d946", "metadata": {}, "source": [ "# Comparison Between TreeValue and Tianshou Batch" ] }, { "cell_type": "markdown", "id": "3c6db2d4", "metadata": {}, "source": [ "In this section, we compare the features and performance of TreeValue with those of the [Tianshou Batch](https://github.com/thu-ml/tianshou) library, which is developed by the Tsinghua Machine Learning Group." ] }, { "cell_type": "markdown", "id": "069361b0", "metadata": {}, "source": [ "Before starting the comparison, let us define the test data used throughout this section." ] }, { "cell_type": "code", "execution_count": 1, "id": "06fc8d26", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:25:44.988480Z", "iopub.status.busy": "2024-10-16T13:25:44.987980Z", "iopub.status.idle": "2024-10-16T13:25:46.185676Z", "shell.execute_reply": "2024-10-16T13:25:46.185004Z" } }, "outputs": [], "source": [ "import torch\n", "\n", "# A small tree of plain Python integers (used in the initialization benchmark)\n", "_TREE_DATA_1 = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}\n", "\n", "# A nested tree of small random tensors (used in the get/set, stack and concat benchmarks)\n", "_TREE_DATA_2 = {\n", "    'a': torch.randn(2, 3),\n", "    'x': {\n", "        'c': torch.randn(3, 4)\n", "    },\n", "}\n", "\n", "# An RL-style transition: an observation of shape (4, 84, 84), a discrete action in [0, 6)\n", "# and a one-element reward (used in the deep copy benchmark)\n", "_TREE_DATA_3 = {\n", "    'obs': torch.randn(4, 84, 84),\n", "    'action': torch.randint(0, 6, size=(1,)),\n", "    'reward': torch.rand(1),\n", "}" ] }, { "cell_type": "markdown", "id": "83461b25", "metadata": {}, "source": [ "## Read and Write Operation" ] }, { "cell_type": "markdown", "id": "067b3f73", "metadata": {}, "source": [ "Reading and writing are the two most common operations on tree-structured data containers (both TreeValue and Tianshou Batch fall into this category), so this section compares the read and write performance of the two libraries."
] }, { "cell_type": "markdown", "id": "5d09a5b7", "metadata": {}, "source": [ "### TreeValue's Get and Set" ] }, { "cell_type": "code", "execution_count": 2, "id": "9519c4bb", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:25:46.188538Z", "iopub.status.busy": "2024-10-16T13:25:46.188038Z", "iopub.status.idle": "2024-10-16T13:25:46.506849Z", "shell.execute_reply": "2024-10-16T13:25:46.506108Z" } }, "outputs": [], "source": [ "from treevalue import FastTreeValue\n", "\n", "t = FastTreeValue(_TREE_DATA_2)" ] }, { "cell_type": "code", "execution_count": 3, "id": "11c37677", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:25:46.509071Z", "iopub.status.busy": "2024-10-16T13:25:46.508610Z", "iopub.status.idle": "2024-10-16T13:25:46.517233Z", "shell.execute_reply": "2024-10-16T13:25:46.516590Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> tensor([[ 0.5140, 2.0014, 0.4726],\n", "│ [-0.3070, 1.5447, 0.5482]])\n", "└── 'x' --> \n", " └── 'c' --> tensor([[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t" ] }, { "cell_type": "code", "execution_count": 4, "id": "fd70b0b9", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:25:46.519345Z", 
"iopub.status.busy": "2024-10-16T13:25:46.518982Z", "iopub.status.idle": "2024-10-16T13:25:46.523568Z", "shell.execute_reply": "2024-10-16T13:25:46.522935Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.5140, 2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482]])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.a" ] }, { "cell_type": "code", "execution_count": 5, "id": "c18197bd", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:25:46.525586Z", "iopub.status.busy": "2024-10-16T13:25:46.525239Z", "iopub.status.idle": "2024-10-16T13:25:50.472740Z", "shell.execute_reply": "2024-10-16T13:25:50.472011Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "48.7 ns ± 0.673 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)\n" ] } ], "source": [ "%timeit t.a" ] }, { "cell_type": "code", "execution_count": 6, "id": "bd52f867", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:25:50.474925Z", "iopub.status.busy": "2024-10-16T13:25:50.474692Z", "iopub.status.idle": "2024-10-16T13:25:50.480577Z", "shell.execute_reply": "2024-10-16T13:25:50.479941Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> tensor([[-0.1345, -1.2116, 1.2098],\n", "│ [ 1.2173, -1.8312, -0.2737]])\n", "└── 'x' --> \n", " └── 'c' --> tensor([[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "new_value = torch.randn(2, 3)\n", "t.a = new_value\n", "\n", "t" ] }, { "cell_type": "code", "execution_count": 7, "id": "bbe04d1c", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:25:50.482681Z", "iopub.status.busy": "2024-10-16T13:25:50.482315Z", "iopub.status.idle": "2024-10-16T13:25:55.089357Z", "shell.execute_reply": "2024-10-16T13:25:55.088623Z" } }, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ "56.7 ns ± 0.0667 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)\n" ] } ], "source": [ "%timeit t.a = new_value" ] }, { "cell_type": "markdown", "id": "48c49731", "metadata": {}, "source": [ "### Tianshou Batch's Get and Set" ] }, { "cell_type": "code", "execution_count": 8, "id": "f1bb14c1", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:25:55.091543Z", "iopub.status.busy": "2024-10-16T13:25:55.091313Z", "iopub.status.idle": "2024-10-16T13:25:55.408055Z", "shell.execute_reply": "2024-10-16T13:25:55.407398Z" } }, "outputs": [], "source": [ "from tianshou.data import Batch\n", "\n", "b = Batch(**_TREE_DATA_2)" ] }, { "cell_type": "code", "execution_count": 9, "id": "cb0777c3", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:25:55.410595Z", "iopub.status.busy": "2024-10-16T13:25:55.410134Z", "iopub.status.idle": "2024-10-16T13:25:55.415489Z", "shell.execute_reply": "2024-10-16T13:25:55.414852Z" } }, "outputs": [ { "data": { "text/plain": [ "Batch(\n", " a: tensor([[ 0.5140, 2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482]]),\n", " x: Batch(\n", " c: tensor([[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]]),\n", " ),\n", ")" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b" ] }, { "cell_type": "code", "execution_count": 10, "id": "43ef8ea3", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:25:55.417712Z", "iopub.status.busy": "2024-10-16T13:25:55.417356Z", "iopub.status.idle": "2024-10-16T13:25:55.421826Z", "shell.execute_reply": "2024-10-16T13:25:55.421299Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.5140, 2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482]])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b.a" ] }, { "cell_type": "code", "execution_count": 11, "id": "b785ab72", 
"metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:25:55.423767Z", "iopub.status.busy": "2024-10-16T13:25:55.423566Z", "iopub.status.idle": "2024-10-16T13:25:58.751055Z", "shell.execute_reply": "2024-10-16T13:25:58.750296Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "41 ns ± 0.362 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)\n" ] } ], "source": [ "%timeit b.a" ] }, { "cell_type": "code", "execution_count": 12, "id": "ad54dc69", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:25:58.753501Z", "iopub.status.busy": "2024-10-16T13:25:58.753000Z", "iopub.status.idle": "2024-10-16T13:25:58.758849Z", "shell.execute_reply": "2024-10-16T13:25:58.758191Z" } }, "outputs": [ { "data": { "text/plain": [ "Batch(\n", " a: tensor([[-1.8601, 2.1694, -1.1358],\n", " [ 0.0145, -1.6993, 0.2390]]),\n", " x: Batch(\n", " c: tensor([[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]]),\n", " ),\n", ")" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "new_value = torch.randn(2, 3)\n", "b.a = new_value\n", "\n", "b" ] }, { "cell_type": "code", "execution_count": 13, "id": "29b1d0bf", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:25:58.761084Z", "iopub.status.busy": "2024-10-16T13:25:58.760711Z", "iopub.status.idle": "2024-10-16T13:26:01.747588Z", "shell.execute_reply": "2024-10-16T13:26:01.746805Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "367 ns ± 0.489 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit b.a = new_value" ] }, { "cell_type": "markdown", "id": "b61ad1d0", "metadata": {}, "source": [ "## Initialization" ] }, { "cell_type": "markdown", "id": "d70f0d54", "metadata": {}, "source": [ "### TreeValue's Initialization" ] }, { "cell_type": "code", "execution_count": 14, "id": "d32a679b", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:26:01.750668Z", "iopub.status.busy": "2024-10-16T13:26:01.749969Z", "iopub.status.idle": "2024-10-16T13:26:06.820705Z", "shell.execute_reply": "2024-10-16T13:26:06.819857Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "625 ns ± 6.03 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit FastTreeValue(_TREE_DATA_1)" ] }, { "cell_type": "markdown", "id": "24f3707b", "metadata": {}, "source": [ "### Tianshou Batch's Initialization" ] }, { "cell_type": "code", "execution_count": 15, "id": "ac3958df", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:26:06.823734Z", "iopub.status.busy": "2024-10-16T13:26:06.823125Z", "iopub.status.idle": "2024-10-16T13:26:13.842563Z", "shell.execute_reply": "2024-10-16T13:26:13.841923Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "8.68 µs ± 15.7 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit Batch(**_TREE_DATA_1)" ] }, { "cell_type": "markdown", "id": "1ab82e2d", "metadata": {}, "source": [ "## Deep Copy Operation" ] }, { "cell_type": "code", "execution_count": 16, "id": "210a9442", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:26:13.844795Z", "iopub.status.busy": "2024-10-16T13:26:13.844383Z", "iopub.status.idle": "2024-10-16T13:26:13.847539Z", "shell.execute_reply": "2024-10-16T13:26:13.846864Z" } }, "outputs": [], "source": [ "import copy" ] }, { "cell_type": "markdown", "id": "5a736274", "metadata": {}, "source": [ "### Deep Copy of TreeValue" ] }, { "cell_type": "code", "execution_count": 17, "id": "f9bcadd6", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:26:13.849630Z", "iopub.status.busy": "2024-10-16T13:26:13.849259Z", "iopub.status.idle": "2024-10-16T13:26:24.247417Z", "shell.execute_reply": "2024-10-16T13:26:24.246680Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "128 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "t3 = FastTreeValue(_TREE_DATA_3)\n", "%timeit copy.deepcopy(t3)" ] }, { "cell_type": "markdown", "id": "bf8be7ea", "metadata": {}, "source": [ "### Deep Copy of Tianshou Batch" ] }, { "cell_type": "code", "execution_count": 18, "id": "91998e6f", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:26:24.250151Z", "iopub.status.busy": "2024-10-16T13:26:24.249683Z", "iopub.status.idle": "2024-10-16T13:26:34.560314Z", "shell.execute_reply": "2024-10-16T13:26:34.559612Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "127 µs ± 526 ns per loop (mean ± std. dev. 
of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "b3 = Batch(**_TREE_DATA_3)\n", "%timeit copy.deepcopy(b3)" ] }, { "cell_type": "markdown", "id": "223162fb", "metadata": {}, "source": [ "## Stack, Concat and Split Operation" ] }, { "cell_type": "markdown", "id": "85fa4a73", "metadata": {}, "source": [ "### Performance of TreeValue" ] }, { "cell_type": "code", "execution_count": 19, "id": "a0c2b697", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:26:34.562655Z", "iopub.status.busy": "2024-10-16T13:26:34.562267Z", "iopub.status.idle": "2024-10-16T13:26:34.565514Z", "shell.execute_reply": "2024-10-16T13:26:34.564952Z" } }, "outputs": [], "source": [ "trees = [FastTreeValue(_TREE_DATA_2) for _ in range(8)]" ] }, { "cell_type": "code", "execution_count": 20, "id": "017ea5a5", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:26:34.567442Z", "iopub.status.busy": "2024-10-16T13:26:34.567115Z", "iopub.status.idle": "2024-10-16T13:26:34.573940Z", "shell.execute_reply": "2024-10-16T13:26:34.573316Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> tensor([[[ 0.5140, 2.0014, 0.4726],\n", "│ [-0.3070, 1.5447, 0.5482]],\n", "│ \n", "│ [[ 0.5140, 2.0014, 0.4726],\n", "│ [-0.3070, 1.5447, 0.5482]],\n", "│ \n", "│ [[ 0.5140, 2.0014, 0.4726],\n", "│ [-0.3070, 1.5447, 0.5482]],\n", "│ \n", "│ [[ 0.5140, 2.0014, 0.4726],\n", "│ [-0.3070, 1.5447, 0.5482]],\n", "│ \n", "│ [[ 0.5140, 2.0014, 0.4726],\n", "│ [-0.3070, 1.5447, 0.5482]],\n", "│ \n", "│ [[ 0.5140, 2.0014, 0.4726],\n", "│ [-0.3070, 1.5447, 0.5482]],\n", "│ \n", "│ [[ 0.5140, 2.0014, 0.4726],\n", "│ [-0.3070, 1.5447, 0.5482]],\n", "│ \n", "│ [[ 0.5140, 2.0014, 0.4726],\n", "│ [-0.3070, 1.5447, 0.5482]]])\n", "└── 'x' --> \n", " └── 'c' --> tensor([[[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]],\n", " \n", " [[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 
0.0272, -0.6516, -0.2357, 0.9524]],\n", " \n", " [[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]],\n", " \n", " [[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]],\n", " \n", " [[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]],\n", " \n", " [[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]],\n", " \n", " [[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]],\n", " \n", " [[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]]])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t_stack = FastTreeValue.func(subside=True)(torch.stack)\n", "\n", "t_stack(trees)" ] }, { "cell_type": "code", "execution_count": 21, "id": "f8b3f415", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:26:34.575920Z", "iopub.status.busy": "2024-10-16T13:26:34.575555Z", "iopub.status.idle": "2024-10-16T13:26:36.514543Z", "shell.execute_reply": "2024-10-16T13:26:36.513796Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "23.8 µs ± 231 ns per loop (mean ± std. dev. 
of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%timeit t_stack(trees)" ] }, { "cell_type": "code", "execution_count": 22, "id": "94b56771", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:26:36.516990Z", "iopub.status.busy": "2024-10-16T13:26:36.516588Z", "iopub.status.idle": "2024-10-16T13:26:36.523497Z", "shell.execute_reply": "2024-10-16T13:26:36.522855Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> tensor([[ 0.5140, 2.0014, 0.4726],\n", "│ [-0.3070, 1.5447, 0.5482],\n", "│ [ 0.5140, 2.0014, 0.4726],\n", "│ [-0.3070, 1.5447, 0.5482],\n", "│ [ 0.5140, 2.0014, 0.4726],\n", "│ [-0.3070, 1.5447, 0.5482],\n", "│ [ 0.5140, 2.0014, 0.4726],\n", "│ [-0.3070, 1.5447, 0.5482],\n", "│ [ 0.5140, 2.0014, 0.4726],\n", "│ [-0.3070, 1.5447, 0.5482],\n", "│ [ 0.5140, 2.0014, 0.4726],\n", "│ [-0.3070, 1.5447, 0.5482],\n", "│ [ 0.5140, 2.0014, 0.4726],\n", "│ [-0.3070, 1.5447, 0.5482],\n", "│ [ 0.5140, 2.0014, 0.4726],\n", "│ [-0.3070, 1.5447, 0.5482]])\n", "└── 'x' --> \n", " └── 'c' --> tensor([[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524],\n", " [ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524],\n", " [ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524],\n", " [ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524],\n", " [ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524],\n", " [ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524],\n", " [ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524],\n", " [ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 
1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t_cat = FastTreeValue.func(subside=True)(torch.cat)\n", "\n", "t_cat(trees)" ] }, { "cell_type": "code", "execution_count": 23, "id": "5e9c06a6", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:26:36.525649Z", "iopub.status.busy": "2024-10-16T13:26:36.525295Z", "iopub.status.idle": "2024-10-16T13:26:38.329588Z", "shell.execute_reply": "2024-10-16T13:26:38.328877Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "22.2 µs ± 612 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%timeit t_cat(trees)" ] }, { "cell_type": "code", "execution_count": 24, "id": "a3ab5c8f", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:26:38.331674Z", "iopub.status.busy": "2024-10-16T13:26:38.331459Z", "iopub.status.idle": "2024-10-16T13:26:42.442320Z", "shell.execute_reply": "2024-10-16T13:26:42.441644Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "50.5 µs ± 383 ns per loop (mean ± std. dev. 
of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "t_split = FastTreeValue.func(rise=True)(torch.split)\n", "tree = FastTreeValue({\n", " 'obs': torch.randn(8, 4, 84, 84),\n", " 'action': torch.randint(0, 6, size=(8, 1,)),\n", " 'reward': torch.rand(8, 1),\n", "})\n", "\n", "%timeit t_split(tree, 1)" ] }, { "cell_type": "markdown", "id": "31c3ec0b", "metadata": {}, "source": [ "### Performance of Tianshou Batch" ] }, { "cell_type": "code", "execution_count": 25, "id": "9ead828a", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:26:42.444912Z", "iopub.status.busy": "2024-10-16T13:26:42.444389Z", "iopub.status.idle": "2024-10-16T13:26:42.451457Z", "shell.execute_reply": "2024-10-16T13:26:42.450793Z" } }, "outputs": [ { "data": { "text/plain": [ "Batch(\n", " x: Batch(\n", " c: tensor([[[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]],\n", " \n", " [[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]],\n", " \n", " [[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]],\n", " \n", " [[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]],\n", " \n", " [[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]],\n", " \n", " [[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]],\n", " \n", " [[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]],\n", " \n", " [[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]]]),\n", " ),\n", " a: tensor([[[ 0.5140, 2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482]],\n", " \n", " [[ 0.5140, 
2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482]],\n", " \n", " [[ 0.5140, 2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482]],\n", " \n", " [[ 0.5140, 2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482]],\n", " \n", " [[ 0.5140, 2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482]],\n", " \n", " [[ 0.5140, 2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482]],\n", " \n", " [[ 0.5140, 2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482]],\n", " \n", " [[ 0.5140, 2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482]]]),\n", ")" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batches = [Batch(**_TREE_DATA_2) for _ in range(8)]\n", "\n", "Batch.stack(batches)" ] }, { "cell_type": "code", "execution_count": 26, "id": "ec9037a3", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:26:42.453675Z", "iopub.status.busy": "2024-10-16T13:26:42.453210Z", "iopub.status.idle": "2024-10-16T13:26:47.620978Z", "shell.execute_reply": "2024-10-16T13:26:47.620225Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "63.6 µs ± 924 ns per loop (mean ± std. dev. 
of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%timeit Batch.stack(batches)" ] }, { "cell_type": "code", "execution_count": 27, "id": "cb8ab77e", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:26:47.623078Z", "iopub.status.busy": "2024-10-16T13:26:47.622854Z", "iopub.status.idle": "2024-10-16T13:26:47.629724Z", "shell.execute_reply": "2024-10-16T13:26:47.629068Z" } }, "outputs": [ { "data": { "text/plain": [ "Batch(\n", " x: Batch(\n", " c: tensor([[ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524],\n", " [ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524],\n", " [ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524],\n", " [ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524],\n", " [ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524],\n", " [ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524],\n", " [ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524],\n", " [ 1.8188, 1.8288, 0.6108, 0.7464],\n", " [-0.8200, -0.1059, 0.4295, 1.1611],\n", " [ 0.0272, -0.6516, -0.2357, 0.9524]]),\n", " ),\n", " a: tensor([[ 0.5140, 2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482],\n", " [ 0.5140, 2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482],\n", " [ 0.5140, 2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482],\n", " [ 0.5140, 2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482],\n", " [ 0.5140, 2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482],\n", " [ 0.5140, 2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482],\n", " [ 0.5140, 2.0014, 0.4726],\n", " [-0.3070, 1.5447, 0.5482],\n", " [ 0.5140, 2.0014, 
0.4726],\n", " [-0.3070, 1.5447, 0.5482]]),\n", ")" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Batch.cat(batches)" ] }, { "cell_type": "code", "execution_count": 28, "id": "18dfb045", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:26:47.631946Z", "iopub.status.busy": "2024-10-16T13:26:47.631578Z", "iopub.status.idle": "2024-10-16T13:26:57.148600Z", "shell.execute_reply": "2024-10-16T13:26:57.147947Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "117 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%timeit Batch.cat(batches)" ] }, { "cell_type": "code", "execution_count": 29, "id": "c6688e51", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:26:57.150985Z", "iopub.status.busy": "2024-10-16T13:26:57.150496Z", "iopub.status.idle": "2024-10-16T13:26:59.427604Z", "shell.execute_reply": "2024-10-16T13:26:59.426859Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "280 µs ± 2.08 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], "source": [ "batch = Batch({\n", "    'obs': torch.randn(8, 4, 84, 84),\n", "    'action': torch.randint(0, 6, size=(8, 1,)),\n", "    'reward': torch.rand(8, 1),\n", "})\n", "\n", "%timeit list(Batch.split(batch, 1, shuffle=False, merge_last=True))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.18" } }, "nbformat": 4, "nbformat_minor": 5 }