{ "cells": [ { "cell_type": "markdown", "id": "9d39d946", "metadata": {}, "source": [ "# Comparison Between TreeValue and Tianshou Batch" ] }, { "cell_type": "markdown", "id": "3c6db2d4", "metadata": {}, "source": [ "In this section, we look at the features and performance of the `Batch` class from the [Tianshou](https://github.com/thu-ml/tianshou) library (developed by the Tsinghua Machine Learning Group) and compare it with TreeValue." ] }, { "cell_type": "markdown", "id": "069361b0", "metadata": {}, "source": [ "Before starting the comparison, let us define some sample tree data that the benchmarks below will use." ] }, { "cell_type": "code", "execution_count": 1, "id": "06fc8d26", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:05:29.834134Z", "iopub.status.busy": "2024-10-16T13:05:29.833886Z", "iopub.status.idle": "2024-10-16T13:05:31.045289Z", "shell.execute_reply": "2024-10-16T13:05:31.044559Z" } }, "outputs": [], "source": [ "import torch\n", "\n", "# A small nested tree of plain Python integers\n", "_TREE_DATA_1 = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}\n", "\n", "# A nested tree of small random tensors\n", "_TREE_DATA_2 = {\n", " 'a': torch.randn(2, 3),\n", " 'x': {\n", " 'c': torch.randn(3, 4)\n", " },\n", "}\n", "\n", "# A flat tree shaped like a single RL transition (stacked 84x84 frames, an action and a reward)\n", "_TREE_DATA_3 = {\n", " 'obs': torch.randn(4, 84, 84),\n", " 'action': torch.randint(0, 6, size=(1,)),\n", " 'reward': torch.rand(1),\n", "}" ] }, { "cell_type": "markdown", "id": "83461b25", "metadata": {}, "source": [ "## Read and Write Operation" ] }, { "cell_type": "markdown", "id": "067b3f73", "metadata": {}, "source": [ "Reading and writing are the two most common operations on tree-structured data objects (both TreeValue and Tianshou Batch fall into this category), so this section compares the read and write performance of the two libraries."
] }, { "cell_type": "markdown", "id": "5d09a5b7", "metadata": {}, "source": [ "### TreeValue's Get and Set" ] }, { "cell_type": "code", "execution_count": 2, "id": "9519c4bb", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:05:31.048303Z", "iopub.status.busy": "2024-10-16T13:05:31.047863Z", "iopub.status.idle": "2024-10-16T13:05:31.074152Z", "shell.execute_reply": "2024-10-16T13:05:31.073578Z" } }, "outputs": [], "source": [ "from treevalue import FastTreeValue\n", "\n", "t = FastTreeValue(_TREE_DATA_2)" ] }, { "cell_type": "code", "execution_count": 3, "id": "11c37677", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:05:31.076169Z", "iopub.status.busy": "2024-10-16T13:05:31.075816Z", "iopub.status.idle": "2024-10-16T13:05:31.084444Z", "shell.execute_reply": "2024-10-16T13:05:31.083814Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> tensor([[-0.3743, -0.9320, -0.5447],\n", "│ [-2.2296, 0.0064, -0.0896]])\n", "└── 'x' --> \n", " └── 'c' --> tensor([[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t" ] }, { "cell_type": "code", "execution_count": 4, "id": "fd70b0b9", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:05:31.086566Z", "iopub.status.busy": "2024-10-16T13:05:31.086206Z", "iopub.status.idle": "2024-10-16T13:05:31.090894Z", "shell.execute_reply": "2024-10-16T13:05:31.090270Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.3743, -0.9320, -0.5447],\n", " [-2.2296, 0.0064, -0.0896]])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.a" ] }, { "cell_type": "code", "execution_count": 5, "id": "c18197bd", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:05:31.093135Z", "iopub.status.busy": "2024-10-16T13:05:31.092775Z", "iopub.status.idle": 
"2024-10-16T13:05:35.111086Z", "shell.execute_reply": "2024-10-16T13:05:35.110360Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "49.4 ns ± 0.39 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)\n" ] } ], "source": [ "%timeit t.a" ] }, { "cell_type": "code", "execution_count": 6, "id": "bd52f867", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:05:35.113465Z", "iopub.status.busy": "2024-10-16T13:05:35.112943Z", "iopub.status.idle": "2024-10-16T13:05:35.119069Z", "shell.execute_reply": "2024-10-16T13:05:35.118394Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> tensor([[-0.5772, 0.2319, -1.2415],\n", "│ [-1.3844, 0.1663, -0.6257]])\n", "└── 'x' --> \n", " └── 'c' --> tensor([[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "new_value = torch.randn(2, 3)\n", "t.a = new_value\n", "\n", "t" ] }, { "cell_type": "code", "execution_count": 7, "id": "bbe04d1c", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:05:35.121104Z", "iopub.status.busy": "2024-10-16T13:05:35.120696Z", "iopub.status.idle": "2024-10-16T13:05:39.500511Z", "shell.execute_reply": "2024-10-16T13:05:39.499817Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "53.4 ns ± 0.121 ns per loop (mean ± std. dev. 
of 7 runs, 10,000,000 loops each)\n" ] } ], "source": [ "%timeit t.a = new_value" ] }, { "cell_type": "markdown", "id": "48c49731", "metadata": {}, "source": [ "### Tianshou Batch's Get and Set" ] }, { "cell_type": "code", "execution_count": 8, "id": "f1bb14c1", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:05:39.503060Z", "iopub.status.busy": "2024-10-16T13:05:39.502851Z", "iopub.status.idle": "2024-10-16T13:05:39.802752Z", "shell.execute_reply": "2024-10-16T13:05:39.802011Z" } }, "outputs": [], "source": [ "from tianshou.data import Batch\n", "\n", "b = Batch(**_TREE_DATA_2)" ] }, { "cell_type": "code", "execution_count": 9, "id": "cb0777c3", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:05:39.805398Z", "iopub.status.busy": "2024-10-16T13:05:39.805107Z", "iopub.status.idle": "2024-10-16T13:05:39.810523Z", "shell.execute_reply": "2024-10-16T13:05:39.809829Z" } }, "outputs": [ { "data": { "text/plain": [ "Batch(\n", " a: tensor([[-0.3743, -0.9320, -0.5447],\n", " [-2.2296, 0.0064, -0.0896]]),\n", " x: Batch(\n", " c: tensor([[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]]),\n", " ),\n", ")" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b" ] }, { "cell_type": "code", "execution_count": 10, "id": "43ef8ea3", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:05:39.812476Z", "iopub.status.busy": "2024-10-16T13:05:39.812268Z", "iopub.status.idle": "2024-10-16T13:05:39.816709Z", "shell.execute_reply": "2024-10-16T13:05:39.816182Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.3743, -0.9320, -0.5447],\n", " [-2.2296, 0.0064, -0.0896]])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b.a" ] }, { "cell_type": "code", "execution_count": 11, "id": "b785ab72", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:05:39.818552Z", 
"iopub.status.busy": "2024-10-16T13:05:39.818350Z", "iopub.status.idle": "2024-10-16T13:05:43.153866Z", "shell.execute_reply": "2024-10-16T13:05:43.153202Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "41.1 ns ± 0.336 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)\n" ] } ], "source": [ "%timeit b.a" ] }, { "cell_type": "code", "execution_count": 12, "id": "ad54dc69", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:05:43.155910Z", "iopub.status.busy": "2024-10-16T13:05:43.155687Z", "iopub.status.idle": "2024-10-16T13:05:43.161402Z", "shell.execute_reply": "2024-10-16T13:05:43.160865Z" } }, "outputs": [ { "data": { "text/plain": [ "Batch(\n", " a: tensor([[-6.5475e-01, -4.4569e-01, 5.2007e-01],\n", " [-1.1818e+00, 1.8087e-04, 9.2330e-01]]),\n", " x: Batch(\n", " c: tensor([[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]]),\n", " ),\n", ")" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "new_value = torch.randn(2, 3)\n", "b.a = new_value\n", "\n", "b" ] }, { "cell_type": "code", "execution_count": 13, "id": "29b1d0bf", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:05:43.163303Z", "iopub.status.busy": "2024-10-16T13:05:43.163093Z", "iopub.status.idle": "2024-10-16T13:05:46.195972Z", "shell.execute_reply": "2024-10-16T13:05:46.195325Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "374 ns ± 3.09 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit b.a = new_value" ] }, { "cell_type": "markdown", "id": "b61ad1d0", "metadata": {}, "source": [ "## Initialization" ] }, { "cell_type": "markdown", "id": "d70f0d54", "metadata": {}, "source": [ "### TreeValue's Initialization" ] }, { "cell_type": "code", "execution_count": 14, "id": "d32a679b", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:05:46.198598Z", "iopub.status.busy": "2024-10-16T13:05:46.198180Z", "iopub.status.idle": "2024-10-16T13:05:51.132389Z", "shell.execute_reply": "2024-10-16T13:05:51.131757Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "608 ns ± 3.41 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit FastTreeValue(_TREE_DATA_1)" ] }, { "cell_type": "markdown", "id": "24f3707b", "metadata": {}, "source": [ "### Tianshou Batch's Initialization" ] }, { "cell_type": "code", "execution_count": 15, "id": "ac3958df", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:05:51.134809Z", "iopub.status.busy": "2024-10-16T13:05:51.134366Z", "iopub.status.idle": "2024-10-16T13:05:58.060139Z", "shell.execute_reply": "2024-10-16T13:05:58.059381Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "8.52 µs ± 101 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit Batch(**_TREE_DATA_1)" ] }, { "cell_type": "markdown", "id": "1ab82e2d", "metadata": {}, "source": [ "## Deep Copy Operation" ] }, { "cell_type": "code", "execution_count": 16, "id": "210a9442", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:05:58.062411Z", "iopub.status.busy": "2024-10-16T13:05:58.062200Z", "iopub.status.idle": "2024-10-16T13:05:58.065030Z", "shell.execute_reply": "2024-10-16T13:05:58.064510Z" } }, "outputs": [], "source": [ "import copy" ] }, { "cell_type": "markdown", "id": "5a736274", "metadata": {}, "source": [ "### Deep Copy of TreeValue" ] }, { "cell_type": "code", "execution_count": 17, "id": "f9bcadd6", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:05:58.066992Z", "iopub.status.busy": "2024-10-16T13:05:58.066671Z", "iopub.status.idle": "2024-10-16T13:06:08.804937Z", "shell.execute_reply": "2024-10-16T13:06:08.804223Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "132 µs ± 852 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "t3 = FastTreeValue(_TREE_DATA_3)\n", "%timeit copy.deepcopy(t3)" ] }, { "cell_type": "markdown", "id": "bf8be7ea", "metadata": {}, "source": [ "### Deep Copy of Tianshou Batch" ] }, { "cell_type": "code", "execution_count": 18, "id": "91998e6f", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:08.807604Z", "iopub.status.busy": "2024-10-16T13:06:08.807207Z", "iopub.status.idle": "2024-10-16T13:06:19.360869Z", "shell.execute_reply": "2024-10-16T13:06:19.360151Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "130 µs ± 1.15 µs per loop (mean ± std. dev. 
of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "b3 = Batch(**_TREE_DATA_3)\n", "%timeit copy.deepcopy(b3)" ] }, { "cell_type": "markdown", "id": "223162fb", "metadata": {}, "source": [ "## Stack, Concat and Split Operation" ] }, { "cell_type": "markdown", "id": "85fa4a73", "metadata": {}, "source": [ "### Performance of TreeValue" ] }, { "cell_type": "code", "execution_count": 19, "id": "a0c2b697", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:19.363574Z", "iopub.status.busy": "2024-10-16T13:06:19.363082Z", "iopub.status.idle": "2024-10-16T13:06:19.366864Z", "shell.execute_reply": "2024-10-16T13:06:19.366296Z" } }, "outputs": [], "source": [ "trees = [FastTreeValue(_TREE_DATA_2) for _ in range(8)]" ] }, { "cell_type": "code", "execution_count": 20, "id": "017ea5a5", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:19.368884Z", "iopub.status.busy": "2024-10-16T13:06:19.368492Z", "iopub.status.idle": "2024-10-16T13:06:19.375504Z", "shell.execute_reply": "2024-10-16T13:06:19.374983Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> tensor([[[-0.3743, -0.9320, -0.5447],\n", "│ [-2.2296, 0.0064, -0.0896]],\n", "│ \n", "│ [[-0.3743, -0.9320, -0.5447],\n", "│ [-2.2296, 0.0064, -0.0896]],\n", "│ \n", "│ [[-0.3743, -0.9320, -0.5447],\n", "│ [-2.2296, 0.0064, -0.0896]],\n", "│ \n", "│ [[-0.3743, -0.9320, -0.5447],\n", "│ [-2.2296, 0.0064, -0.0896]],\n", "│ \n", "│ [[-0.3743, -0.9320, -0.5447],\n", "│ [-2.2296, 0.0064, -0.0896]],\n", "│ \n", "│ [[-0.3743, -0.9320, -0.5447],\n", "│ [-2.2296, 0.0064, -0.0896]],\n", "│ \n", "│ [[-0.3743, -0.9320, -0.5447],\n", "│ [-2.2296, 0.0064, -0.0896]],\n", "│ \n", "│ [[-0.3743, -0.9320, -0.5447],\n", "│ [-2.2296, 0.0064, -0.0896]]])\n", "└── 'x' --> \n", " └── 'c' --> tensor([[[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]],\n", " \n", " [[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 
0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]],\n", " \n", " [[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]],\n", " \n", " [[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]],\n", " \n", " [[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]],\n", " \n", " [[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]],\n", " \n", " [[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]],\n", " \n", " [[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]]])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t_stack = FastTreeValue.func(subside=True)(torch.stack)\n", "\n", "t_stack(trees)" ] }, { "cell_type": "code", "execution_count": 21, "id": "f8b3f415", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:19.377387Z", "iopub.status.busy": "2024-10-16T13:06:19.377181Z", "iopub.status.idle": "2024-10-16T13:06:21.301879Z", "shell.execute_reply": "2024-10-16T13:06:21.301168Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "23.7 µs ± 30.8 ns per loop (mean ± std. dev. 
of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%timeit t_stack(trees)" ] }, { "cell_type": "code", "execution_count": 22, "id": "94b56771", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:21.304307Z", "iopub.status.busy": "2024-10-16T13:06:21.303772Z", "iopub.status.idle": "2024-10-16T13:06:21.310805Z", "shell.execute_reply": "2024-10-16T13:06:21.310156Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> tensor([[-0.3743, -0.9320, -0.5447],\n", "│ [-2.2296, 0.0064, -0.0896],\n", "│ [-0.3743, -0.9320, -0.5447],\n", "│ [-2.2296, 0.0064, -0.0896],\n", "│ [-0.3743, -0.9320, -0.5447],\n", "│ [-2.2296, 0.0064, -0.0896],\n", "│ [-0.3743, -0.9320, -0.5447],\n", "│ [-2.2296, 0.0064, -0.0896],\n", "│ [-0.3743, -0.9320, -0.5447],\n", "│ [-2.2296, 0.0064, -0.0896],\n", "│ [-0.3743, -0.9320, -0.5447],\n", "│ [-2.2296, 0.0064, -0.0896],\n", "│ [-0.3743, -0.9320, -0.5447],\n", "│ [-2.2296, 0.0064, -0.0896],\n", "│ [-0.3743, -0.9320, -0.5447],\n", "│ [-2.2296, 0.0064, -0.0896]])\n", "└── 'x' --> \n", " └── 'c' --> tensor([[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452],\n", " [-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452],\n", " [-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452],\n", " [-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452],\n", " [-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452],\n", " [-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452],\n", " [-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452],\n", " [-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 
0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t_cat = FastTreeValue.func(subside=True)(torch.cat)\n", "\n", "t_cat(trees)" ] }, { "cell_type": "code", "execution_count": 23, "id": "5e9c06a6", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:21.313046Z", "iopub.status.busy": "2024-10-16T13:06:21.312662Z", "iopub.status.idle": "2024-10-16T13:06:23.079903Z", "shell.execute_reply": "2024-10-16T13:06:23.079159Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "21.7 µs ± 31 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%timeit t_cat(trees)" ] }, { "cell_type": "code", "execution_count": 24, "id": "a3ab5c8f", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:23.081934Z", "iopub.status.busy": "2024-10-16T13:06:23.081721Z", "iopub.status.idle": "2024-10-16T13:06:27.200715Z", "shell.execute_reply": "2024-10-16T13:06:27.199978Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "50.6 µs ± 245 ns per loop (mean ± std. dev. 
of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "t_split = FastTreeValue.func(rise=True)(torch.split)\n", "tree = FastTreeValue({\n", " 'obs': torch.randn(8, 4, 84, 84),\n", " 'action': torch.randint(0, 6, size=(8, 1,)),\n", " 'reward': torch.rand(8, 1),\n", "})\n", "\n", "%timeit t_split(tree, 1)" ] }, { "cell_type": "markdown", "id": "31c3ec0b", "metadata": {}, "source": [ "### Performance of Tianshou Batch" ] }, { "cell_type": "code", "execution_count": 25, "id": "9ead828a", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:27.203147Z", "iopub.status.busy": "2024-10-16T13:06:27.202749Z", "iopub.status.idle": "2024-10-16T13:06:27.209633Z", "shell.execute_reply": "2024-10-16T13:06:27.209118Z" } }, "outputs": [ { "data": { "text/plain": [ "Batch(\n", " x: Batch(\n", " c: tensor([[[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]],\n", " \n", " [[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]],\n", " \n", " [[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]],\n", " \n", " [[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]],\n", " \n", " [[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]],\n", " \n", " [[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]],\n", " \n", " [[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]],\n", " \n", " [[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]]]),\n", " ),\n", " a: tensor([[[-0.3743, -0.9320, -0.5447],\n", " [-2.2296, 0.0064, -0.0896]],\n", " \n", " [[-0.3743, -0.9320, -0.5447],\n", " 
[-2.2296, 0.0064, -0.0896]],\n", " \n", " [[-0.3743, -0.9320, -0.5447],\n", " [-2.2296, 0.0064, -0.0896]],\n", " \n", " [[-0.3743, -0.9320, -0.5447],\n", " [-2.2296, 0.0064, -0.0896]],\n", " \n", " [[-0.3743, -0.9320, -0.5447],\n", " [-2.2296, 0.0064, -0.0896]],\n", " \n", " [[-0.3743, -0.9320, -0.5447],\n", " [-2.2296, 0.0064, -0.0896]],\n", " \n", " [[-0.3743, -0.9320, -0.5447],\n", " [-2.2296, 0.0064, -0.0896]],\n", " \n", " [[-0.3743, -0.9320, -0.5447],\n", " [-2.2296, 0.0064, -0.0896]]]),\n", ")" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batches = [Batch(**_TREE_DATA_2) for _ in range(8)]\n", "\n", "Batch.stack(batches)" ] }, { "cell_type": "code", "execution_count": 26, "id": "ec9037a3", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:27.211640Z", "iopub.status.busy": "2024-10-16T13:06:27.211249Z", "iopub.status.idle": "2024-10-16T13:06:32.326778Z", "shell.execute_reply": "2024-10-16T13:06:32.326149Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "63 µs ± 243 ns per loop (mean ± std. dev. 
of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%timeit Batch.stack(batches)" ] }, { "cell_type": "code", "execution_count": 27, "id": "cb8ab77e", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:32.329053Z", "iopub.status.busy": "2024-10-16T13:06:32.328623Z", "iopub.status.idle": "2024-10-16T13:06:32.335276Z", "shell.execute_reply": "2024-10-16T13:06:32.334577Z" } }, "outputs": [ { "data": { "text/plain": [ "Batch(\n", " x: Batch(\n", " c: tensor([[-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452],\n", " [-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452],\n", " [-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452],\n", " [-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452],\n", " [-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452],\n", " [-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452],\n", " [-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452],\n", " [-0.0560, 1.0876, 0.1732, 2.0784],\n", " [ 1.2565, 0.5128, 0.9535, 0.1456],\n", " [ 1.4677, 0.0500, 0.5396, 0.0452]]),\n", " ),\n", " a: tensor([[-0.3743, -0.9320, -0.5447],\n", " [-2.2296, 0.0064, -0.0896],\n", " [-0.3743, -0.9320, -0.5447],\n", " [-2.2296, 0.0064, -0.0896],\n", " [-0.3743, -0.9320, -0.5447],\n", " [-2.2296, 0.0064, -0.0896],\n", " [-0.3743, -0.9320, -0.5447],\n", " [-2.2296, 0.0064, -0.0896],\n", " [-0.3743, -0.9320, -0.5447],\n", " [-2.2296, 0.0064, -0.0896],\n", " [-0.3743, -0.9320, -0.5447],\n", " [-2.2296, 0.0064, -0.0896],\n", " [-0.3743, -0.9320, -0.5447],\n", " [-2.2296, 0.0064, -0.0896],\n", " [-0.3743, -0.9320, 
-0.5447],\n", " [-2.2296, 0.0064, -0.0896]]),\n", ")" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Batch.cat(batches)" ] }, { "cell_type": "code", "execution_count": 28, "id": "18dfb045", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:32.337276Z", "iopub.status.busy": "2024-10-16T13:06:32.337066Z", "iopub.status.idle": "2024-10-16T13:06:41.993040Z", "shell.execute_reply": "2024-10-16T13:06:41.992275Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "119 µs ± 364 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%timeit Batch.cat(batches)" ] }, { "cell_type": "code", "execution_count": 29, "id": "c6688e51", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:06:41.995294Z", "iopub.status.busy": "2024-10-16T13:06:41.995050Z", "iopub.status.idle": "2024-10-16T13:06:44.274892Z", "shell.execute_reply": "2024-10-16T13:06:44.274148Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "280 µs ± 2.31 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], "source": [ "batch = Batch({\n", " 'obs': torch.randn(8, 4, 84, 84),\n", " 'action': torch.randint(0, 6, size=(8, 1,)),\n", " 'reward': torch.rand(8, 1),\n", "})\n", "\n", "%timeit list(Batch.split(batch, 1, shuffle=False, merge_last=True))" ] }, { "cell_type": "code", "execution_count": null, "id": "2539fbd9", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.18" } }, "nbformat": 4, "nbformat_minor": 5 }