{ "cells": [ { "cell_type": "markdown", "id": "9d39d946", "metadata": {}, "source": [ "# Comparison Between TreeValue and Tianshou Batch" ] }, { "cell_type": "markdown", "id": "3c6db2d4", "metadata": {}, "source": [ "In this section, we take a look at the features and performance of the [Tianshou Batch](https://github.com/thu-ml/tianshou) library, developed by the Tsinghua Machine Learning Group, and compare it with TreeValue." ] }, { "cell_type": "markdown", "id": "069361b0", "metadata": {}, "source": [ "Before starting the comparison, let us define some sample tree data for the benchmarks." ] }, { "cell_type": "code", "execution_count": 1, "id": "06fc8d26", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:17:30.026957Z", "iopub.status.busy": "2024-10-16T13:17:30.026748Z", "iopub.status.idle": "2024-10-16T13:17:31.243062Z", "shell.execute_reply": "2024-10-16T13:17:31.242397Z" } }, "outputs": [], "source": [ "import torch\n", "\n", "# a small nested tree of plain Python integers\n", "_TREE_DATA_1 = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}\n", "# a nested tree of random tensors\n", "_TREE_DATA_2 = {\n", " 'a': torch.randn(2, 3),\n", " 'x': {\n", " 'c': torch.randn(3, 4)\n", " },\n", "}\n", "# an RL-style transition with observation, action and reward tensors\n", "_TREE_DATA_3 = {\n", " 'obs': torch.randn(4, 84, 84),\n", " 'action': torch.randint(0, 6, size=(1,)),\n", " 'reward': torch.rand(1),\n", "}" ] }, { "cell_type": "markdown", "id": "83461b25", "metadata": {}, "source": [ "## Read and Write Operations" ] }, { "cell_type": "markdown", "id": "067b3f73", "metadata": {}, "source": [ "Reading and writing are the two most common operations on tree-structured data containers such as TreeValue and Tianshou Batch, so this section compares the read and write performance of the two libraries."
] }, { "cell_type": "markdown", "id": "5d09a5b7", "metadata": {}, "source": [ "### TreeValue's Get and Set" ] }, { "cell_type": "code", "execution_count": 2, "id": "9519c4bb", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:17:31.246001Z", "iopub.status.busy": "2024-10-16T13:17:31.245485Z", "iopub.status.idle": "2024-10-16T13:17:31.573012Z", "shell.execute_reply": "2024-10-16T13:17:31.572276Z" } }, "outputs": [], "source": [ "from treevalue import FastTreeValue\n", "\n", "t = FastTreeValue(_TREE_DATA_2)" ] }, { "cell_type": "code", "execution_count": 3, "id": "11c37677", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:17:31.575937Z", "iopub.status.busy": "2024-10-16T13:17:31.575312Z", "iopub.status.idle": "2024-10-16T13:17:31.584283Z", "shell.execute_reply": "2024-10-16T13:17:31.583646Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> tensor([[-0.6042, -0.6698, -1.1385],\n", "│ [-1.0516, -1.4043, -0.1458]])\n", "└── 'x' --> \n", " └── 'c' --> tensor([[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t" ] }, { "cell_type": "code", "execution_count": 4, "id": "fd70b0b9", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:17:31.586561Z", "iopub.status.busy": "2024-10-16T13:17:31.586195Z", "iopub.status.idle": "2024-10-16T13:17:31.590623Z", "shell.execute_reply": "2024-10-16T13:17:31.590096Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458]])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.a" ] }, { "cell_type": "code", "execution_count": 5, "id": "c18197bd", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:17:31.592596Z", "iopub.status.busy": "2024-10-16T13:17:31.592226Z", "iopub.status.idle": 
"2024-10-16T13:17:35.388222Z", "shell.execute_reply": "2024-10-16T13:17:35.387495Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "46.6 ns ± 0.9 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)\n" ] } ], "source": [ "%timeit t.a" ] }, { "cell_type": "code", "execution_count": 6, "id": "bd52f867", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:17:35.390483Z", "iopub.status.busy": "2024-10-16T13:17:35.390112Z", "iopub.status.idle": "2024-10-16T13:17:35.396151Z", "shell.execute_reply": "2024-10-16T13:17:35.395523Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> tensor([[-0.3939, 0.9285, -0.0783],\n", "│ [ 0.5034, 0.0010, -0.4377]])\n", "└── 'x' --> \n", " └── 'c' --> tensor([[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "new_value = torch.randn(2, 3)\n", "t.a = new_value\n", "\n", "t" ] }, { "cell_type": "code", "execution_count": 7, "id": "bbe04d1c", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:17:35.398269Z", "iopub.status.busy": "2024-10-16T13:17:35.397857Z", "iopub.status.idle": "2024-10-16T13:17:39.475538Z", "shell.execute_reply": "2024-10-16T13:17:39.474861Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "50.2 ns ± 0.582 ns per loop (mean ± std. dev. 
of 7 runs, 10,000,000 loops each)\n" ] } ], "source": [ "%timeit t.a = new_value" ] }, { "cell_type": "markdown", "id": "48c49731", "metadata": {}, "source": [ "### Tianshou Batch's Get and Set" ] }, { "cell_type": "code", "execution_count": 8, "id": "f1bb14c1", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:17:39.477655Z", "iopub.status.busy": "2024-10-16T13:17:39.477429Z", "iopub.status.idle": "2024-10-16T13:17:39.805962Z", "shell.execute_reply": "2024-10-16T13:17:39.805253Z" } }, "outputs": [], "source": [ "from tianshou.data import Batch\n", "\n", "b = Batch(**_TREE_DATA_2)" ] }, { "cell_type": "code", "execution_count": 9, "id": "cb0777c3", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:17:39.808442Z", "iopub.status.busy": "2024-10-16T13:17:39.808123Z", "iopub.status.idle": "2024-10-16T13:17:39.813616Z", "shell.execute_reply": "2024-10-16T13:17:39.812982Z" } }, "outputs": [ { "data": { "text/plain": [ "Batch(\n", " a: tensor([[-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458]]),\n", " x: Batch(\n", " c: tensor([[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]]),\n", " ),\n", ")" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b" ] }, { "cell_type": "code", "execution_count": 10, "id": "43ef8ea3", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:17:39.815786Z", "iopub.status.busy": "2024-10-16T13:17:39.815287Z", "iopub.status.idle": "2024-10-16T13:17:39.819848Z", "shell.execute_reply": "2024-10-16T13:17:39.819202Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458]])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b.a" ] }, { "cell_type": "code", "execution_count": 11, "id": "b785ab72", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:17:39.821984Z", 
"iopub.status.busy": "2024-10-16T13:17:39.821608Z", "iopub.status.idle": "2024-10-16T13:17:43.152495Z", "shell.execute_reply": "2024-10-16T13:17:43.151757Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "41 ns ± 0.293 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)\n" ] } ], "source": [ "%timeit b.a" ] }, { "cell_type": "code", "execution_count": 12, "id": "ad54dc69", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:17:43.154886Z", "iopub.status.busy": "2024-10-16T13:17:43.154484Z", "iopub.status.idle": "2024-10-16T13:17:43.160067Z", "shell.execute_reply": "2024-10-16T13:17:43.159433Z" } }, "outputs": [ { "data": { "text/plain": [ "Batch(\n", " a: tensor([[-1.0806, 0.8120, -0.5090],\n", " [-1.1575, -0.1895, -0.6689]]),\n", " x: Batch(\n", " c: tensor([[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]]),\n", " ),\n", ")" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "new_value = torch.randn(2, 3)\n", "b.a = new_value\n", "\n", "b" ] }, { "cell_type": "code", "execution_count": 13, "id": "29b1d0bf", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:17:43.162160Z", "iopub.status.busy": "2024-10-16T13:17:43.161751Z", "iopub.status.idle": "2024-10-16T13:17:46.129144Z", "shell.execute_reply": "2024-10-16T13:17:46.128380Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "366 ns ± 3.48 ns per loop (mean ± std. dev. 
of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit b.a = new_value" ] }, { "cell_type": "markdown", "id": "b61ad1d0", "metadata": {}, "source": [ "## Initialization" ] }, { "cell_type": "markdown", "id": "d70f0d54", "metadata": {}, "source": [ "### TreeValue's Initialization" ] }, { "cell_type": "code", "execution_count": 14, "id": "d32a679b", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:17:46.131581Z", "iopub.status.busy": "2024-10-16T13:17:46.131155Z", "iopub.status.idle": "2024-10-16T13:17:51.477125Z", "shell.execute_reply": "2024-10-16T13:17:51.476443Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "658 ns ± 0.998 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n" ] } ], "source": [ "%timeit FastTreeValue(_TREE_DATA_1)" ] }, { "cell_type": "markdown", "id": "24f3707b", "metadata": {}, "source": [ "### Tianshou Batch's Initialization" ] }, { "cell_type": "code", "execution_count": 15, "id": "ac3958df", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:17:51.479259Z", "iopub.status.busy": "2024-10-16T13:17:51.479051Z", "iopub.status.idle": "2024-10-16T13:17:58.430923Z", "shell.execute_reply": "2024-10-16T13:17:58.430307Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "8.56 µs ± 49.4 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit Batch(**_TREE_DATA_1)" ] }, { "cell_type": "markdown", "id": "1ab82e2d", "metadata": {}, "source": [ "## Deep Copy Operation" ] }, { "cell_type": "code", "execution_count": 16, "id": "210a9442", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:17:58.433084Z", "iopub.status.busy": "2024-10-16T13:17:58.432846Z", "iopub.status.idle": "2024-10-16T13:17:58.436022Z", "shell.execute_reply": "2024-10-16T13:17:58.435449Z" } }, "outputs": [], "source": [ "import copy" ] }, { "cell_type": "markdown", "id": "5a736274", "metadata": {}, "source": [ "### Deep Copy of TreeValue" ] }, { "cell_type": "code", "execution_count": 17, "id": "f9bcadd6", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:17:58.438015Z", "iopub.status.busy": "2024-10-16T13:17:58.437803Z", "iopub.status.idle": "2024-10-16T13:18:09.156909Z", "shell.execute_reply": "2024-10-16T13:18:09.156167Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "132 µs ± 476 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "t3 = FastTreeValue(_TREE_DATA_3)\n", "%timeit copy.deepcopy(t3)" ] }, { "cell_type": "markdown", "id": "bf8be7ea", "metadata": {}, "source": [ "### Deep Copy of Tianshou Batch" ] }, { "cell_type": "code", "execution_count": 18, "id": "91998e6f", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:09.159426Z", "iopub.status.busy": "2024-10-16T13:18:09.158982Z", "iopub.status.idle": "2024-10-16T13:18:19.668371Z", "shell.execute_reply": "2024-10-16T13:18:19.667678Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "129 µs ± 1.33 µs per loop (mean ± std. dev. 
of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "b3 = Batch(**_TREE_DATA_3)\n", "%timeit copy.deepcopy(b3)" ] }, { "cell_type": "markdown", "id": "223162fb", "metadata": {}, "source": [ "## Stack, Concat and Split Operations" ] }, { "cell_type": "markdown", "id": "85fa4a73", "metadata": {}, "source": [ "### Performance of TreeValue" ] }, { "cell_type": "code", "execution_count": 19, "id": "a0c2b697", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:19.670915Z", "iopub.status.busy": "2024-10-16T13:18:19.670430Z", "iopub.status.idle": "2024-10-16T13:18:19.674134Z", "shell.execute_reply": "2024-10-16T13:18:19.673550Z" } }, "outputs": [], "source": [ "trees = [FastTreeValue(_TREE_DATA_2) for _ in range(8)]" ] }, { "cell_type": "code", "execution_count": 20, "id": "017ea5a5", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:19.676237Z", "iopub.status.busy": "2024-10-16T13:18:19.675842Z", "iopub.status.idle": "2024-10-16T13:18:19.682664Z", "shell.execute_reply": "2024-10-16T13:18:19.681969Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> tensor([[[-0.6042, -0.6698, -1.1385],\n", "│ [-1.0516, -1.4043, -0.1458]],\n", "│ \n", "│ [[-0.6042, -0.6698, -1.1385],\n", "│ [-1.0516, -1.4043, -0.1458]],\n", "│ \n", "│ [[-0.6042, -0.6698, -1.1385],\n", "│ [-1.0516, -1.4043, -0.1458]],\n", "│ \n", "│ [[-0.6042, -0.6698, -1.1385],\n", "│ [-1.0516, -1.4043, -0.1458]],\n", "│ \n", "│ [[-0.6042, -0.6698, -1.1385],\n", "│ [-1.0516, -1.4043, -0.1458]],\n", "│ \n", "│ [[-0.6042, -0.6698, -1.1385],\n", "│ [-1.0516, -1.4043, -0.1458]],\n", "│ \n", "│ [[-0.6042, -0.6698, -1.1385],\n", "│ [-1.0516, -1.4043, -0.1458]],\n", "│ \n", "│ [[-0.6042, -0.6698, -1.1385],\n", "│ [-1.0516, -1.4043, -0.1458]]])\n", "└── 'x' --> \n", " └── 'c' --> tensor([[[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]],\n", " \n", " [[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 
0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]],\n", " \n", " [[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]],\n", " \n", " [[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]],\n", " \n", " [[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]],\n", " \n", " [[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]],\n", " \n", " [[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]],\n", " \n", " [[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]]])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# subside=True merges the list of trees into one tree whose leaves are lists,\n", "# so torch.stack is applied to the matching leaves of all the trees\n", "t_stack = FastTreeValue.func(subside=True)(torch.stack)\n", "\n", "t_stack(trees)" ] }, { "cell_type": "code", "execution_count": 21, "id": "f8b3f415", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:19.684751Z", "iopub.status.busy": "2024-10-16T13:18:19.684298Z", "iopub.status.idle": "2024-10-16T13:18:21.669510Z", "shell.execute_reply": "2024-10-16T13:18:21.668873Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "24.4 µs ± 346 ns per loop (mean ± std. dev. 
of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%timeit t_stack(trees)" ] }, { "cell_type": "code", "execution_count": 22, "id": "94b56771", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:21.671699Z", "iopub.status.busy": "2024-10-16T13:18:21.671470Z", "iopub.status.idle": "2024-10-16T13:18:21.678423Z", "shell.execute_reply": "2024-10-16T13:18:21.677856Z" } }, "outputs": [ { "data": { "text/plain": [ "\n", "├── 'a' --> tensor([[-0.6042, -0.6698, -1.1385],\n", "│ [-1.0516, -1.4043, -0.1458],\n", "│ [-0.6042, -0.6698, -1.1385],\n", "│ [-1.0516, -1.4043, -0.1458],\n", "│ [-0.6042, -0.6698, -1.1385],\n", "│ [-1.0516, -1.4043, -0.1458],\n", "│ [-0.6042, -0.6698, -1.1385],\n", "│ [-1.0516, -1.4043, -0.1458],\n", "│ [-0.6042, -0.6698, -1.1385],\n", "│ [-1.0516, -1.4043, -0.1458],\n", "│ [-0.6042, -0.6698, -1.1385],\n", "│ [-1.0516, -1.4043, -0.1458],\n", "│ [-0.6042, -0.6698, -1.1385],\n", "│ [-1.0516, -1.4043, -0.1458],\n", "│ [-0.6042, -0.6698, -1.1385],\n", "│ [-1.0516, -1.4043, -0.1458]])\n", "└── 'x' --> \n", " └── 'c' --> tensor([[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659],\n", " [ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659],\n", " [ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659],\n", " [ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659],\n", " [ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659],\n", " [ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659],\n", " [ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659],\n", " [ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 
0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t_cat = FastTreeValue.func(subside=True)(torch.cat)\n", "\n", "t_cat(trees)" ] }, { "cell_type": "code", "execution_count": 23, "id": "5e9c06a6", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:21.680337Z", "iopub.status.busy": "2024-10-16T13:18:21.680129Z", "iopub.status.idle": "2024-10-16T13:18:23.495961Z", "shell.execute_reply": "2024-10-16T13:18:23.495334Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "22.3 µs ± 274 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%timeit t_cat(trees)" ] }, { "cell_type": "code", "execution_count": 24, "id": "a3ab5c8f", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:23.498205Z", "iopub.status.busy": "2024-10-16T13:18:23.497760Z", "iopub.status.idle": "2024-10-16T13:18:27.635970Z", "shell.execute_reply": "2024-10-16T13:18:27.635336Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "50.9 µs ± 404 ns per loop (mean ± std. dev. 
of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "# rise=True turns the tree of tuples returned by torch.split\n", "# into a sequence of trees, one per size-1 slice along the first dimension\n", "t_split = FastTreeValue.func(rise=True)(torch.split)\n", "tree = FastTreeValue({\n", " 'obs': torch.randn(8, 4, 84, 84),\n", " 'action': torch.randint(0, 6, size=(8, 1)),\n", " 'reward': torch.rand(8, 1),\n", "})\n", "\n", "%timeit t_split(tree, 1)" ] }, { "cell_type": "markdown", "id": "31c3ec0b", "metadata": {}, "source": [ "### Performance of Tianshou Batch" ] }, { "cell_type": "code", "execution_count": 25, "id": "9ead828a", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:27.638216Z", "iopub.status.busy": "2024-10-16T13:18:27.637976Z", "iopub.status.idle": "2024-10-16T13:18:27.647606Z", "shell.execute_reply": "2024-10-16T13:18:27.646712Z" } }, "outputs": [ { "data": { "text/plain": [ "Batch(\n", " a: tensor([[[-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458]],\n", " \n", " [[-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458]],\n", " \n", " [[-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458]],\n", " \n", " [[-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458]],\n", " \n", " [[-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458]],\n", " \n", " [[-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458]],\n", " \n", " [[-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458]],\n", " \n", " [[-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458]]]),\n", " x: Batch(\n", " c: tensor([[[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]],\n", " \n", " [[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]],\n", " \n", " [[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]],\n", " \n", " [[ 0.5048, 0.6089, 
-2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]],\n", " \n", " [[ 
0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]],\n", " \n", " [[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]],\n", " \n", " [[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]],\n", " \n", " [[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]]]),\n", " ),\n", ")" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batches = [Batch(**_TREE_DATA_2) for _ in range(8)]\n", "\n", "Batch.stack(batches)" ] }, { "cell_type": "code", "execution_count": 26, "id": "ec9037a3", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:27.649840Z", "iopub.status.busy": "2024-10-16T13:18:27.649637Z", "iopub.status.idle": "2024-10-16T13:18:32.871804Z", "shell.execute_reply": "2024-10-16T13:18:32.871159Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "64.2 µs ± 339 ns per loop (mean ± std. dev. 
of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%timeit Batch.stack(batches)" ] }, { "cell_type": "code", "execution_count": 27, "id": "cb8ab77e", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:32.873877Z", "iopub.status.busy": "2024-10-16T13:18:32.873665Z", "iopub.status.idle": "2024-10-16T13:18:32.880242Z", "shell.execute_reply": "2024-10-16T13:18:32.879718Z" } }, "outputs": [ { "data": { "text/plain": [ "Batch(\n", " a: tensor([[-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458],\n", " [-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458],\n", " [-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458],\n", " [-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458],\n", " [-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458],\n", " [-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458],\n", " [-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458],\n", " [-0.6042, -0.6698, -1.1385],\n", " [-1.0516, -1.4043, -0.1458]]),\n", " x: Batch(\n", " c: tensor([[ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659],\n", " [ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659],\n", " [ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659],\n", " [ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659],\n", " [ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659],\n", " [ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659],\n", " [ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659],\n", " [ 0.5048, 0.6089, -2.0348, 1.0639],\n", " [ 0.3563, 0.7882, 
1.1929, -1.0652],\n", " [ 0.8315, 0.1529, 1.0504, -0.6659]]),\n", " ),\n", ")" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Batch.cat(batches)" ] }, { "cell_type": "code", "execution_count": 28, "id": "18dfb045", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:32.882209Z", "iopub.status.busy": "2024-10-16T13:18:32.881981Z", "iopub.status.idle": "2024-10-16T13:18:42.594192Z", "shell.execute_reply": "2024-10-16T13:18:42.593444Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "120 µs ± 711 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%timeit Batch.cat(batches)" ] }, { "cell_type": "code", "execution_count": 29, "id": "c6688e51", "metadata": { "execution": { "iopub.execute_input": "2024-10-16T13:18:42.596647Z", "iopub.status.busy": "2024-10-16T13:18:42.596171Z", "iopub.status.idle": "2024-10-16T13:18:44.938001Z", "shell.execute_reply": "2024-10-16T13:18:44.937250Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "287 µs ± 3.57 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], "source": [ "batch = Batch({\n", " 'obs': torch.randn(8, 4, 84, 84),\n", " 'action': torch.randint(0, 6, size=(8, 1)),\n", " 'reward': torch.rand(8, 1),\n", "})\n", "\n", "# Batch.split yields the sub-batches lazily, so list() forces all of them to be built\n", "%timeit list(Batch.split(batch, 1, shuffle=False, merge_last=True))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.18" } }, "nbformat": 4, "nbformat_minor": 5 }