Application to Scikit-Learn

TreeValue can be used in practice not only with numpy or torch but also with other libraries such as scikit-learn. The following part shows a demo of applying PCA to tree-structured arrays.

In traditional machine learning, PCA (Principal Component Analysis) is often used to preprocess data. It centers the data and rotates it onto the directions of greatest variance, so that low-variance directions can be dropped. This reduces the dimensionality and complexity of the input data, which can improve both the efficiency and the quality of machine learning, as illustrated in the following image:

Figure: PCA in a nutshell. Source: Lavrenko and Sutton 2011, slide 13.
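The projection described above can be sketched with NumPy alone. The steps are: center each column, take a singular value decomposition, then project onto the leading right singular vectors. This is a minimal sketch of textbook PCA, assuming plain 2-D arrays. The name `pca_fit_transform` is hypothetical, and this is not scikit-learn's implementation, which also chooses a solver and normalizes component signs. Individual components may therefore differ in sign from scikit-learn's output.

```python
import numpy as np


def pca_fit_transform(x: np.ndarray, n_components: int) -> np.ndarray:
    """Hypothetical helper: project the rows of x onto its first principal axes."""
    x = np.asarray(x, dtype=float)
    # 1. Center every feature (column) at zero. PCA centers but does not rescale.
    centered = x - x.mean(axis=0)
    # 2. SVD of the centered data: the rows of vt are the principal axes,
    #    ordered by decreasing singular value, i.e. decreasing variance.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    # 3. Coordinates of each sample along the first n_components axes.
    return centered @ vt[:n_components].T


rng = np.random.default_rng(0)
samples = rng.integers(-5, 15, size=(6, 3))
reduced = pca_fit_transform(samples, 2)  # 3 features -> 2 components
print(reduced.shape)  # (6, 2)
```

Keeping only the first `n_components` axes discards the directions with the least variance. This is the dimensionality reduction step.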

In scikit-learn, the PCA class implements this technique. Its fit_transform method fits the model to the data and returns the transformed data in one step. For np.ndarray data organized in a tree structure, we can add tree support simply by wrapping fit_transform with FastTreeValue.func(), which applies the wrapped function to every array in the tree. The code is as follows:

import numpy as np
from sklearn.decomposition import PCA

from treevalue import FastTreeValue

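# FastTreeValue.func() turns a function on a single np.ndarray into one that
# is applied to every leaf array of a tree. n_components is set to
# min(n_samples, n_features), the largest value PCA accepts for each array.
fit_transform = FastTreeValue.func()(lambda x: PCA(min(*x.shape)).fit_transform(x))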

if __name__ == '__main__':
    data = FastTreeValue({
        'a': np.random.randint(-5, 15, (4, 3)),
        'x': {
            'c': np.random.randint(-15, 5, (5, 4)),
        }
    })
    print("Original int data:")
    print(data)

    pdata = fit_transform(data)
    print("Fit transformed data:")
    print(pdata)

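The output should look like the following. The exact values and object addresses vary from run to run, since the input is random.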

Original int data:
<FastTreeValue 0x7f7b0d8d0b20>
├── 'a' --> array([[-2,  0, 14],
│                  [ 2,  3, 12],
│                  [ 9,  1, 14],
│                  [ 7,  2, 11]])
└── 'x' --> <FastTreeValue 0x7f7b0d953ee0>
    └── 'c' --> array([[  3, -13,  -7,  -9],
                       [-11,  -5,   2,  -2],
                       [ -8,  -4,   2,  -8],
                       [  4,   4,  -1, -10],
                       [  1, -15,   1,   2]])

Fit transformed data:
<FastTreeValue 0x7f7afbabaac0>
├── 'a' --> array([[ 6.17617397,  1.25181164,  0.31676925],
│                  [ 1.8150677 , -1.7351034 , -0.71235205],
│                  [-4.83435138,  1.82202575, -0.348954  ],
│                  [-3.15689028, -1.33873399,  0.74453681]])
└── 'x' --> <FastTreeValue 0x7f7b0ca58e80>
    └── 'c' --> array([[  5.709154  ,  -8.19568333,   4.68254226,  -0.47972205],
                       [ -1.56860411,   9.65983312,   0.53895793,  -1.50220825],
                       [ -4.47633238,   4.89169311,   2.64620777,   1.71795561],
                       [-10.46499906,  -7.17057088,  -3.33562913,  -0.24553269],
                       [ 10.80078154,   0.81472799,  -4.53207884,   0.50950739]])
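In this demo all components are kept, so each transformed array has the same shape as its input. PCA has only centered and rotated each leaf. The sketch below checks this on leaf 'a', using the values copied from the output above. It confirms four properties:

- each column is centered;
- the columns are mutually orthogonal;
- their variances decrease from left to right;
- each row keeps its distance from the column means.

It uses NumPy only. The tolerance is an assumption, chosen to absorb the 8-decimal rounding of the printed values.

```python
import numpy as np

# Leaf 'a' before and after fit_transform, copied from the output above.
a = np.array([[-2, 0, 14],
              [ 2, 3, 12],
              [ 9, 1, 14],
              [ 7, 2, 11]], dtype=float)
pa = np.array([[ 6.17617397,  1.25181164,  0.31676925],
               [ 1.8150677 , -1.7351034 , -0.71235205],
               [-4.83435138,  1.82202575, -0.348954  ],
               [-3.15689028, -1.33873399,  0.74453681]])

tol = 1e-5  # assumed tolerance: printed values are rounded to 8 decimals

# 1. Every principal component (column) has zero mean.
centered_ok = np.allclose(pa.mean(axis=0), 0, atol=tol)

# 2. Components are mutually orthogonal: off-diagonal Gram entries are ~0.
gram = pa.T @ pa
orthogonal_ok = np.allclose(gram - np.diag(np.diag(gram)), 0, atol=tol)

# 3. The variance along each component decreases from left to right.
decreasing_ok = bool(np.all(np.diff(pa.var(axis=0)) < 0))

# 4. All 3 components are kept, so PCA only rotates the centered data:
#    every row keeps its distance to the column means.
centered_a = a - a.mean(axis=0)
rotation_ok = np.allclose(np.linalg.norm(pa, axis=1),
                          np.linalg.norm(centered_a, axis=1), atol=tol)

print(centered_ok, orthogonal_ok, decreasing_ok, rotation_ok)  # True True True True
```

To actually reduce dimensionality, a smaller `n_components` would be passed to PCA inside the wrapped function.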

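Conceptually, the wrapped fit_transform maps a per-array function over every leaf while keeping the tree's structure. The sketch below reproduces that idea with plain nested dicts instead of FastTreeValue. It is only an illustration, not how TreeValue is implemented. `map_leaves` and `pca_all_components` are hypothetical helpers. The latter uses an SVD as a stand-in for scikit-learn's PCA, so component signs may differ.

```python
from typing import Any, Callable, Dict, Union

import numpy as np

# A "tree" here is a nested dict whose leaves are NumPy arrays.
Tree = Union[np.ndarray, Dict[str, "Tree"]]


def map_leaves(func: Callable[[np.ndarray], Any], tree: Tree) -> Any:
    """Hypothetical helper: apply func to every leaf, keeping the dict structure."""
    if isinstance(tree, dict):
        return {key: map_leaves(func, value) for key, value in tree.items()}
    return func(tree)


def pca_all_components(x: np.ndarray) -> np.ndarray:
    """Stand-in for PCA(min(*x.shape)).fit_transform(x), up to component signs."""
    centered = x - x.mean(axis=0)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt.T  # equals U * S from the SVD of the centered data


data = {
    'a': np.random.randint(-5, 15, (4, 3)),
    'x': {'c': np.random.randint(-15, 5, (5, 4))},
}
pdata = map_leaves(pca_all_components, data)
print(pdata['a'].shape, pdata['x']['c'].shape)  # (4, 3) (5, 4)
```

FastTreeValue.func() gives the same leaf-wise behaviour without writing the recursion by hand. It also returns a tree object rather than a dict.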
For further information, see the links below: