Apply to Scikit-Learn
In practice, TreeValue is not limited to numpy or torch. It can also be combined with other libraries, such as scikit-learn.
The following part demonstrates how to apply PCA to tree-structured arrays.
In traditional machine learning, PCA (Principal Component Analysis) is often used to preprocess data. It centers the data and projects it onto the directions of greatest variance, so that the dimensionality of the data can be reduced while retaining as much of its variance as possible. This lowers the complexity of the input data and can improve both the efficiency and the quality of machine learning, as illustrated in the following image.
 
PCA in a nutshell. Source: Lavrenko and Sutton 2011, slide 13.
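To make the description above concrete, here is a minimal NumPy-only sketch of what PCA computes: center each column, take the singular value decomposition, and project onto the first k principal axes. The helper name `pca_reduce` is hypothetical, and this is a textbook formulation rather than scikit-learn's exact implementation, which additionally chooses between solvers and applies a deterministic sign convention to the components.

```python
import numpy as np


def pca_reduce(x: np.ndarray, k: int) -> np.ndarray:
    """Project the rows of ``x`` onto its top-``k`` principal components.

    Hypothetical helper for illustration only; it is not scikit-learn's code.
    """
    x = np.asarray(x, dtype=float)
    # Center each feature (column) so the principal axes pass through the mean.
    centered = x - x.mean(axis=0)
    # Rows of vt are the principal axes, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    # Projecting onto the first k axes gives the reduced representation.
    return centered @ vt[:k].T


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    data = rng.integers(-5, 15, size=(6, 4))
    reduced = pca_reduce(data, 2)
    print(data.shape, '->', reduced.shape)  # (6, 4) -> (6, 2)
```

When n_samples >= n_features and k = min(n_samples, n_features), which is the choice made by `min(*x.shape)` in the scikit-learn example of this page, the projection is an orthogonal change of basis of the centered data, so the output keeps the input's shape. The dimensionality is only actually reduced when k is smaller.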
Scikit-learn provides this functionality through the PCA class, whose fit_transform method fits the model to the data and returns the transformed data in a single call.
For np.ndarray data arranged in a tree structure, tree support can be added by quickly wrapping fit_transform with FastTreeValue.func(), so that it is applied to every array in the tree.
The specific code is as follows:
```python
import numpy as np
from sklearn.decomposition import PCA

from treevalue import FastTreeValue

# Wrap a plain function so that it is applied to every array (leaf) in a tree.
# n_components = min(n_samples, n_features) keeps all principal components.
fit_transform = FastTreeValue.func()(lambda x: PCA(min(*x.shape)).fit_transform(x))

if __name__ == '__main__':
    # A tree of integer arrays with different shapes.
    data = FastTreeValue({
        'a': np.random.randint(-5, 15, (4, 3)),
        'x': {
            'c': np.random.randint(-15, 5, (5, 4)),
        }
    })
    print("Original int data:")
    print(data)

    # Each leaf is fitted and transformed by its own PCA instance.
    pdata = fit_transform(data)
    print("Fit transformed data:")
    print(pdata)
```
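The output should look similar to the following. The exact values differ on each run because the input arrays are generated randomly.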
```text
Original int data:
<FastTreeValue 0x7f70c1a90b20>
├── 'a' --> array([[ 2, -4,  7],
│                  [-2,  9,  9],
│                  [ 1, -1, 11],
│                  [ 3, -2, 13]])
└── 'x' --> <FastTreeValue 0x7f70c1b13ee0>
    └── 'c' --> array([[-12, -14,   2,  -2],
                       [-13, -15, -15,   2],
                       [ -6,   1,   2, -12],
                       [  4,   4, -15,  -3],
                       [-11,   2, -15,  -2]])
Fit transformed data:
<FastTreeValue 0x7f70afcbaac0>
├── 'a' --> array([[-4.38193083,  3.31911943, -0.17924406],
│                  [ 9.05831985,  0.42852711, -0.1149177 ],
│                  [-1.47241942, -0.77671989,  0.69187228],
│                  [-3.20396959, -2.97092665, -0.39771052]])
└── 'x' --> <FastTreeValue 0x7f70c0c18e80>
    └── 'c' --> array([[-13.41632895,   5.43366804,  -2.25429746,   1.7074871 ],
                       [ -9.37879213, -11.20729905,  -0.73347577,  -1.67830858],
                       [  2.73235784,  14.03489906,   1.74025994,  -1.49910784],
                       [ 14.71844268,  -2.81717602,  -5.1786862 ,   0.36355692],
                       [  5.34432055,  -5.44409203,   6.42619949,   1.10637239]])
```
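The wrapping step can be pictured with a plain-Python sketch. The hypothetical `tree_map` helper below walks a nested dict and applies a function to each leaf array while keeping the keys, which is roughly what the wrapped fit_transform does to each array in the tree; TreeValue's own implementation differs and handles more cases. To stay dependency-free, the sketch maps a simple column-centering function (the first step PCA performs) instead of scikit-learn's PCA.

```python
import numpy as np


def tree_map(func, tree):
    """Apply ``func`` to every leaf of a nested dict, keeping the same keys.

    Hypothetical helper illustrating the idea behind wrapping a function
    with FastTreeValue.func(); it is not TreeValue's implementation.
    """
    if isinstance(tree, dict):
        # Internal node: recurse into each child and rebuild the dict.
        return {key: tree_map(func, child) for key, child in tree.items()}
    # Leaf: apply the wrapped function directly.
    return func(tree)


def center(x):
    # Subtract each column's mean, as PCA does before projecting.
    return x - x.mean(axis=0)


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    data = {
        'a': rng.integers(-5, 15, size=(4, 3)),
        'x': {
            'c': rng.integers(-15, 5, size=(5, 4)),
        },
    }
    centered = tree_map(center, data)
    print(centered['a'].shape, centered['x']['c'].shape)  # (4, 3) (5, 4)
```

Passing `lambda x: PCA(min(*x.shape)).fit_transform(x)` as `func` instead would produce a nested dict with the same layout as the tree output above.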
For further information, see the links below: