# Source code for treetensor.torch.funcs.wrapper

import torch

from .base import doc_from_base, wrap_for_treelize, _is_torch_2

__all__ = [
    'vmap',
]

if _is_torch_2:
    # A working ``vmap`` is only provided when the project-level
    # ``_is_torch_2`` flag (from ``.base``) is true; it delegates to
    # ``torch.vmap``. ``doc_from_base`` and ``wrap_for_treelize`` are project
    # helpers from ``.base`` -- presumably they reuse torch's documentation
    # and lift the call over tree-structured inputs; confirm against ``.base``.
    @doc_from_base()
    @wrap_for_treelize()
    def vmap(func, *args, **kwargs):
        return torch.vmap(func, *args, **kwargs)

else:
    def vmap(func, *args, **kwargs):
        """
        Placeholder for :func:`treetensor.torch.vmap` on torch 1.x.

        .. warning::
            :func:`treetensor.torch.vmap` is not supported for torch 1.x.

        :param func: Function that would be vectorized (ignored here).
        :param args: Extra positional arguments (ignored here).
        :param kwargs: Extra keyword arguments (ignored here).
        :raises NotImplementedError: Always, reporting the installed torch version.
        """
        raise NotImplementedError(f'Function vmap is not supported in torch {torch.__version__}.')