Apply into Scikit-Learn
TreeValue can be used in practice not only with numpy or torch, but also with other libraries such as scikit-learn.
The following part shows a demo of applying PCA to tree-structured arrays.
In traditional machine learning, PCA (Principal Component Analysis) is often used to preprocess data: it centers the data and projects it onto a smaller number of directions, reducing its dimensionality so as to lower the complexity of the input data and improve the efficiency and quality of machine learning, as illustrated in the accompanying image.
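To make the idea concrete, here is a minimal numpy-only sketch of the projection that PCA performs. The helper name pca_fit_transform is hypothetical and is not scikit-learn's implementation; it assumes a plain SVD of the centered data, so its result may differ from scikit-learn's output in the sign of each column.

```python
import numpy as np


def pca_fit_transform(x, n_components):
    """Project the rows of x onto their first n_components principal axes."""
    x = np.asarray(x, dtype=float)
    # PCA centers each feature (column) at zero; it does not rescale them.
    centered = x - x.mean(axis=0)
    # Rows of vt are the principal axes, ordered by decreasing singular value.
    u, s, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T


rng = np.random.default_rng(0)
samples = rng.integers(-5, 15, size=(6, 4))
reduced = pca_fit_transform(samples, 2)
print(reduced.shape)  # (6, 2)
```

Each sample keeps its row, but only the two directions with the largest variance remain as columns.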
In the scikit-learn library, the PCA class provides this functionality, and its fit_transform method can be used to reduce the data. For a set of np.array data arranged in a tree structure, we can add support for the tree structure by simply wrapping the fit_transform call.
The specific code is as follows:
    import numpy as np
    from sklearn.decomposition import PCA

    from treevalue import FastTreeValue

    fit_transform = FastTreeValue.func()(lambda x: PCA(min(*x.shape)).fit_transform(x))

    if __name__ == '__main__':
        data = FastTreeValue({
            'a': np.random.randint(-5, 15, (4, 3)),
            'x': {
                'c': np.random.randint(-15, 5, (5, 4)),
            }
        })
        print("Original int data:")
        print(data)

        pdata = fit_transform(data)
        print("Fit transformed data:")
        print(pdata)
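The output should look like the following (the input is random and unseeded, so the values and object addresses differ between runs):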
    Original int data:
    <FastTreeValue 0x7f9f84335b20>
    ├── 'a' --> array([[ 9,  9, -3],
    │                  [-5, 14,  6],
    │                  [ 4,  7,  0],
    │                  [ 3, -1,  6]])
    └── 'x' --> <FastTreeValue 0x7f9f84393ee0>
        └── 'c' --> array([[  3,  -3,  -5,  -8],
                           [ -8,  -2, -14,  -7],
                           [  3, -12, -15,  -4],
                           [-11,  -1, -15,   4],
                           [  3,  -7,   0,  -3]])
    Fit transformed data:
    <FastTreeValue 0x7f9f726baac0>
    ├── 'a' --> array([[-5.64361948, -6.12795566, -0.5340586 ],
    │                  [10.90642714, -0.83906228, -0.18254208],
    │                  [-1.91236952, -1.45341863,  0.9580799 ],
    │                  [-3.35043814,  8.42043657, -0.24147922]])
    └── 'x' --> <FastTreeValue 0x7f9f836d8e80>
        └── 'c' --> array([[-7.1110395 , -2.53269706,  3.38732589,  1.71597963],
                           [ 6.54756677, -0.50033091,  5.41738764, -1.31565553],
                           [-2.36661527,  9.68194513, -1.36207401,  0.0618468 ],
                           [12.64056741, -2.52269583, -3.88699477,  0.73711125],
                           [-9.71047941, -4.12622132, -3.55564475, -1.19928214]])
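Conceptually, the wrapped function is applied to each array in the tree while the tree layout is preserved. The sketch below imitates that behavior with plain nested dicts; map_leaves and center are hypothetical names used only for illustration, and this is not how TreeValue is implemented internally.

```python
import numpy as np


def map_leaves(func, tree):
    """Apply func to every non-dict value of a nested dict, keeping its layout."""
    if isinstance(tree, dict):
        return {key: map_leaves(func, value) for key, value in tree.items()}
    return func(tree)


def center(x):
    # Stand-in leaf function; the demo above uses PCA's fit_transform instead.
    return x - x.mean(axis=0)


data = {
    'a': np.arange(12).reshape(4, 3),
    'x': {'c': np.arange(20).reshape(5, 4)},
}
result = map_leaves(center, data)
print(result['a'].shape, result['x']['c'].shape)  # (4, 3) (5, 4)
```

The result has the same keys and nesting as the input, which is the property the tree-wrapped fit_transform provides without hand-written recursion.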
For further information, see the links below: