masked_select
Documentation

treetensor.torch.masked_select(input, mask, *args, reduce=None, **kwargs)
Returns a new 1-D tensor which indexes the input tensor according to the boolean mask mask, which is a BoolTensor.

Examples:
    >>> import torch
    >>> import treetensor.torch as ttorch
    >>> t = torch.randn(3, 4)
    >>> t
    tensor([[ 0.0481,  0.1741,  0.9820, -0.6354],
            [ 0.8108, -0.7126,  0.1329,  1.0868],
            [-1.8267,  1.3676, -1.4490, -2.0224]])
    >>> ttorch.masked_select(t, t > 0.3)
    tensor([0.9820, 0.8108, 1.0868, 1.3676])

    >>> tt = ttorch.randn({
    ...     'a': (2, 3),
    ...     'b': {'x': (3, 4)},
    ... })
    >>> tt
    <Tensor 0x7f0be77bbc88>
    ├── a --> tensor([[ 1.1799,  0.4652, -1.7895],
    │                 [ 0.0423,  1.0866,  1.3533]])
    └── b --> <Tensor 0x7f0be77bbb70>
        └── x --> tensor([[ 0.8139, -0.6732,  0.0065,  0.9073],
                          [ 0.0596, -2.0621, -0.1598, -1.0793],
                          [-0.0496,  2.1392,  0.6403,  0.4041]])
    >>> ttorch.masked_select(tt, tt > 0.3)
    tensor([1.1799, 0.4652, 1.0866, 1.3533, 0.8139, 0.9073, 2.1392, 0.6403, 0.4041])
    >>> ttorch.masked_select(tt, tt > 0.3, reduce=False)
    <Tensor 0x7fcb64456b38>
    ├── a --> tensor([1.1799, 0.4652, 1.0866, 1.3533])
    └── b --> <Tensor 0x7fcb64456a58>
        └── x --> tensor([0.8139, 0.9073, 2.1392, 0.6403, 0.4041])
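The difference between the default call and reduce=False in the examples above can be pictured with plain PyTorch. The following is only a sketch, assuming that a tree of tensors behaves like a nested dict; select_leaves and flatten are hypothetical helpers written for illustration, not part of treetensor, and the library's actual implementation may differ.

```python
import torch

# Hypothetical stand-in for a tree of tensors: a plain nested dict.
tree = {
    "a": torch.tensor([[1.0, 0.5, -2.0], [0.0, 1.0, 1.5]]),
    "b": {"x": torch.tensor([[0.75, -0.5], [2.0, 0.25]])},
}


def select_leaves(node, threshold):
    """Apply torch.masked_select to every leaf, keeping the nesting."""
    if isinstance(node, dict):
        return {key: select_leaves(value, threshold) for key, value in node.items()}
    return torch.masked_select(node, node > threshold)


def flatten(node):
    """Collect the leaf tensors of a nested dict in insertion order."""
    if isinstance(node, dict):
        leaves = []
        for value in node.values():
            leaves.extend(flatten(value))
        return leaves
    return [node]


# Analogous to reduce=False: one 1-D tensor per leaf, structure preserved.
per_leaf = select_leaves(tree, 0.3)

# Analogous to the default call: all selected values in a single 1-D tensor.
combined = torch.cat(flatten(per_leaf))

print(per_leaf)
print(combined)
```

In this sketch the combined tensor lists the values selected from a first, followed by those from b.x, which matches the ordering shown in the example output above.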
Torch Version Related
This documentation is based on torch.masked_select in torch v2.4.1+cu121. The arrangement of its arguments depends on the version of PyTorch you have installed.
If some of the arguments listed here do not work properly, check your PyTorch version with the following command and consult the documentation for that version.
    python -c 'import torch;print(torch.__version__)'
The arguments and keyword arguments supported in torch v2.4.1+cu121 are listed below.
Description From Torch v2.4.1+cu121

torch.masked_select(input, mask, *, out=None) → Tensor
Returns a new 1-D tensor which indexes the input tensor according to the boolean mask mask, which is a BoolTensor.
The shapes of the mask tensor and the input tensor don't need to match, but they must be broadcastable.

Note
The returned tensor does not use the same storage as the original tensor.
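The broadcasting rule and the storage note can both be checked with a small experiment. This is a minimal sketch using hand-picked values; the variable names are chosen for illustration only.

```python
import torch

x = torch.tensor([[1.0, -2.0, 3.0],
                  [-4.0, 5.0, -6.0]])  # shape (2, 3)

# A mask of shape (3,) is broadcast across both rows of x,
# so columns 0 and 2 are picked from every row.
col_mask = torch.tensor([True, False, True])
by_column = torch.masked_select(x, col_mask)

# A mask of shape (2, 1) is broadcast across the columns,
# so the whole first row is picked.
row_mask = torch.tensor([[True], [False]])
by_row = torch.masked_select(x, row_mask)

by_column_values = by_column.tolist()
by_row_values = by_row.tolist()
print(by_column_values)
print(by_row_values)

# The result is a copy: writing to it leaves x unchanged.
by_column[0] = 100.0
print(x[0, 0].item())
```

A mask whose shape cannot be broadcast against the input, for example a mask of shape (2,) against an input of shape (2, 3), is expected to raise an error instead.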
Args:
    input (Tensor): the input tensor.
    mask (BoolTensor): the tensor containing the binary mask to index with.
Keyword args:
    out (Tensor, optional): the output tensor.
Example:
    >>> x = torch.randn(3, 4)
    >>> x
    tensor([[ 0.3552, -2.3825, -0.8297,  0.3477],
            [-1.2035,  1.2252,  0.5002,  0.6248],
            [ 0.1307, -2.0608,  0.1244,  2.0139]])
    >>> mask = x.ge(0.5)
    >>> mask
    tensor([[False, False, False, False],
            [False, True, True, True],
            [False, False, False, True]])
    >>> torch.masked_select(x, mask)
    tensor([ 1.2252,  0.5002,  0.6248,  2.0139])
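The out keyword argument listed above is not exercised by the example. A minimal sketch of its use follows, assuming an empty buffer of the same dtype as the input; the name buf is chosen for illustration only.

```python
import torch

x = torch.tensor([[0.25, 1.5],
                  [2.0, -1.0]])
mask = x.ge(0.5)

# An empty float32 buffer; masked_select resizes it to fit the result.
buf = torch.empty(0)
torch.masked_select(x, mask, out=buf)

print(buf)
```

Because the number of selected elements depends on the mask, the output length is generally not known in advance, which is why the sketch starts from an empty buffer rather than a pre-sized one.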