Apply to Scikit-Learn

In practice, TreeValue can be used not only with numpy or torch, but also with other libraries such as scikit-learn. The following part shows a demo of applying PCA to tree-structured arrays.

In traditional machine learning, PCA (Principal Component Analysis) is often used to preprocess data. It centers the data and projects it onto the directions of greatest variance, which can reduce the dimensionality of the input, lower its complexity, and improve the efficiency and quality of the downstream learning. The idea is illustrated in the following figure:

[Figure: PCA principle]

PCA in a nutshell. Source: Lavrenko and Sutton 2011, slide 13.
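
To make the figure concrete, here is a minimal NumPy sketch of what PCA computes, assuming the usual SVD formulation: center each column, take the SVD of the centered matrix, and project onto the leading right singular vectors. The helper name `pca_svd` is hypothetical and is not part of scikit-learn or TreeValue. For the same number of components, its output should typically agree with scikit-learn's `PCA.fit_transform` up to the sign of each column, because scikit-learn applies its own sign convention.

```python
import numpy as np


def pca_svd(X: np.ndarray, n_components: int) -> np.ndarray:
    """Project X onto its first n_components principal axes (hypothetical sketch)."""
    X = np.asarray(X, dtype=float)
    # Center each feature (column) at zero mean; no rescaling is applied.
    Xc = X - X.mean(axis=0)
    # Rows of Vt are the principal axes, ordered by decreasing singular value.
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    # Coordinates of every sample along the leading principal axes.
    return Xc @ Vt[:n_components].T


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.integers(-5, 15, size=(4, 3))
    Z = pca_svd(X, min(*X.shape))
    print(Z.shape)  # (4, 3)
    # Variance along each principal axis, largest first.
    print(Z.var(axis=0))
```

With `n_components = min(*X.shape)` and more rows than columns, no dimension is discarded: the data is only rotated onto uncorrelated axes, so the total variance is preserved. This is why, in the demo, each transformed array keeps the shape of its input.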

In the scikit-learn library, the PCA class provides this functionality, and its fit_transform method fits the model and returns the transformed data in a single call. For np.array data organized as a tree, we can add tree support simply by wrapping fit_transform with FastTreeValue.func(). The code is as follows:

import numpy as np
from sklearn.decomposition import PCA

from treevalue import FastTreeValue

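# FastTreeValue.func() turns a per-array function into one that is applied to
# every leaf of a tree. PCA(min(*x.shape)) keeps as many components as the
# smaller dimension of x allows, so no dimension is dropped.
fit_transform = FastTreeValue.func()(lambda x: PCA(min(*x.shape)).fit_transform(x))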

if __name__ == '__main__':
    data = FastTreeValue({
        'a': np.random.randint(-5, 15, (4, 3)),
        'x': {
            'c': np.random.randint(-15, 5, (5, 4)),
        }
    })
    print("Original int data:")
    print(data)

    # PCA runs independently on each leaf; results keep their positions in the tree.
    pdata = fit_transform(data)
    print("Fit transformed data:")
    print(pdata)

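The output should look similar to the following. The exact numbers differ on every run because the input arrays are generated randomly, and the memory addresses in the FastTreeValue headers vary as well: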

Original int data:
<FastTreeValue 0x7f12b5ded370>
├── 'a' --> array([[ 5, 10,  8],
│                  [ 2, -5,  8],
│                  [ 1, -4, 14],
│                  [ 7, -4,  7]])
└── 'x' --> <FastTreeValue 0x7f12b596ee50>
    └── 'c' --> array([[-10,   2,   2,  -9],
                       [-15,   0, -14,  -6],
                       [ -1,   0, -15,   3],
                       [-10,  -4,  -4, -11],
                       [-12,  -4,   2,   1]])

Fit transformed data:
<FastTreeValue 0x7f12b5d7f6d0>
├── 'a' --> array([[10.87824485,  0.57103045,  0.15878664],
│                  [-4.24316583, -0.82637024,  2.00003898],
│                  [-4.30356117,  4.62611836, -0.87515152],
│                  [-2.33151785, -4.37077857, -1.28367409]])
└── 'x' --> <FastTreeValue 0x7f12a58024c0>
    └── 'c' --> array([[-8.37563731, -0.19822628, -3.58837668, -3.05498916],
                       [ 4.54441555,  7.85342722,  4.07913381, -1.19692756],
                       [13.78198432, -4.28660674, -2.54591226, -0.01554565],
                       [-4.52520671,  3.83400453, -2.6146956 ,  3.57579088],
                       [-5.42555584, -7.20259873,  4.66985073,  0.69167149]])

For further information, see the links below: