The official PyTorch implementation of orthogonal initialization, torch.nn.init.orthogonal_.

Fills the input tensor with a (semi-)orthogonal matrix, as described in Saxe et al. (2013), "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks".
The input tensor must have at least 2 dimensions, and for tensors with more than 2 dimensions the trailing dimensions are flattened.
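The small NumPy sketch below makes the shape bookkeeping concrete. The helper names flattened_shape and is_semi_orthogonal are hypothetical and not part of PyTorch. A 4-D convolution weight of shape (4, 4, 3, 3) is treated as a 4 × 36 matrix. "Semi-orthogonal" here means that the rows are orthonormal when the matrix is wide, and the columns are orthonormal otherwise.

```python
import math

import numpy as np


def flattened_shape(shape):
    """Hypothetical helper: (rows, cols) of the matrix that orthogonal init works on."""
    if len(shape) < 2:
        raise ValueError("need a tensor with at least 2 dimensions")
    rows = shape[0]
    cols = math.prod(shape[1:])  # all trailing dimensions are flattened into one
    return rows, cols


def is_semi_orthogonal(m, atol=1e-6):
    """Hypothetical helper: True if rows (wide) or columns (tall/square) are orthonormal."""
    rows, cols = m.shape
    if rows < cols:
        gram, n = m @ m.T, rows  # wide matrix: W W^T should be the identity
    else:
        gram, n = m.T @ m, cols  # tall or square matrix: W^T W should be the identity
    return np.allclose(gram, np.eye(n), atol=atol)


print(flattened_shape((4, 4, 3, 3)))  # (4, 36): a conv weight
print(flattened_shape((8, 3)))        # (8, 3): a tall linear weight
print(is_semi_orthogonal(np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])))  # True
```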

import torch

def orthogonal_(tensor: torch.Tensor, gain: float = 1) -> torch.Tensor:

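Sample a 2-D matrix of shape (rows, cols) from the standard normal distribution. Here rows is the size of the first dimension and cols is the product of all remaining dimensions, so the matrix has the same number of elements as the input tensor, though not necessarily the same shape.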

    rows = tensor.size(0)
    cols = tensor.numel() // rows
    flattened = tensor.new(rows, cols).normal_(0, 1)

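If rows < cols, transpose the sampled matrix so that it is tall (cols × rows). This is needed for the output to have the right shape: the reduced QR factorization of a tall matrix returns a Q of the same shape with orthonormal columns, whereas for a wide matrix it would return only a rows × rows Q.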

    if rows < cols:
        flattened.t_()
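The NumPy sketch below shows why the transpose is needed. It assumes that np.linalg.qr, like torch.linalg.qr, defaults to reduced mode, where Q has shape (m, min(m, n)). For a wide matrix Q comes back too small. The transposed (tall) matrix gives a Q that, transposed back, has the target shape and orthonormal rows.

```python
import numpy as np

rng = np.random.default_rng(0)
rows, cols = 2, 5
wide = rng.standard_normal((rows, cols))

# Reduced QR of a wide matrix: Q is only rows x rows, not the rows x cols we need.
q_wide, _ = np.linalg.qr(wide)
print(q_wide.shape)  # (2, 2)

# Transpose first: the tall matrix gives Q of shape cols x rows with orthonormal columns.
q_tall, _ = np.linalg.qr(wide.T)
print(q_tall.shape)  # (5, 2)

# Transposing Q back yields a rows x cols matrix whose rows are orthonormal.
w = q_tall.T
print(w.shape)  # (2, 5)
print(np.allclose(w @ w.T, np.eye(rows)))  # True
```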

Compute the QR factorization of the (now tall) matrix: flattened = QR, where Q has orthonormal columns and R is upper triangular.

    q, r = torch.linalg.qr(flattened)
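As a quick NumPy check of the properties this step relies on, the sketch below factors a tall random matrix in reduced mode. NumPy is used here only for illustration. Q has orthonormal columns, R is upper triangular, and their product reproduces the input.

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.standard_normal((6, 3))  # tall, like `flattened` after the optional transpose

q, r = np.linalg.qr(a)  # reduced mode: q is 6 x 3, r is 3 x 3

print(np.allclose(q.T @ q, np.eye(3)))  # True: orthonormal columns
print(np.allclose(r, np.triu(r)))       # True: R is upper triangular
print(np.allclose(q @ r, a))            # True: the factorization reproduces A
```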

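Although Q has orthonormal columns, it is not uniformly distributed over the orthogonal matrices. The QR routine is free to choose the sign of each diagonal entry of R, and its sign convention biases the distribution of Q. To obtain a uniformly (Haar) distributed Q, multiply each column of Q by the sign of the corresponding diagonal entry of R:

$$Q^* = Q\,\Lambda, \qquad \Lambda = \operatorname{diag}\big(\operatorname{sign}(r_{11}), \dots, \operatorname{sign}(r_{kk})\big).$$

A proof is given in Mezzadri, "How to generate random matrices from the classical compact groups".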

    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph
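The sign fix can be checked numerically. Because every entry of Λ is ±1, Λ² = I, so A = QR = (QΛ)(ΛR). Q* = QΛ is therefore still part of a valid QR factorization, and ΛR has a positive diagonal. Fixing the diagonal of R to be positive makes the factorization unique, so Q* no longer depends on the sign convention of the QR routine. The NumPy sketch below, an illustration rather than PyTorch code, also shows that q * ph broadcasts over columns, as q *= ph does above.

```python
import numpy as np

rng = np.random.default_rng(2)
a = rng.standard_normal((5, 3))
q, r = np.linalg.qr(a)

ph = np.sign(np.diag(r))    # sign of each diagonal entry of R, one per column of Q
q_fixed = q * ph            # broadcasting scales column j of Q by ph[j] (Q Lambda)
r_fixed = ph[:, None] * r   # the matching row scaling of R (Lambda R)

print(np.all(np.diag(r_fixed) > 0))                 # True: diagonal is now positive
print(np.allclose(q_fixed @ r_fixed, a))            # True: still a valid factorization
print(np.allclose(q_fixed.T @ q_fixed, np.eye(3)))  # True: columns still orthonormal
```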

If rows < cols, transpose Q back so that it has shape (rows, cols), matching the flattened input tensor.

    if rows < cols:
        q.t_()

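Wrapping the following in-place operations in torch.no_grad() keeps them out of the graph recorded by PyTorch's autograd system. This is required, not just an optimization: the tensor being initialized is usually a parameter with requires_grad=True, and PyTorch raises an error when such a leaf tensor is modified in place while gradients are being tracked.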

    with torch.no_grad():

View the input tensor as a (rows, cols) matrix and copy Q into it in place.

        tensor.view_as(q).copy_(q)

Scale the result by the optional gain factor.

        tensor.mul_(gain)

    return tensor
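For readers without PyTorch at hand, here is a minimal NumPy port that mirrors the steps above. The name orthogonal_np is hypothetical, and it returns a new array instead of filling one in place, so no autograd handling is needed. The usage line also shows the effect of gain: since (gW)(gW)ᵀ = g²WWᵀ, a wide weight initialized with gain g satisfies WWᵀ = g²I.

```python
import numpy as np


def orthogonal_np(shape, gain=1.0, rng=None):
    """Hypothetical NumPy sketch mirroring the steps of orthogonal_."""
    if len(shape) < 2:
        raise ValueError("need at least 2 dimensions")
    rng = np.random.default_rng() if rng is None else rng
    rows = shape[0]
    cols = int(np.prod(shape[1:]))
    flattened = rng.standard_normal((rows, cols))

    if rows < cols:
        flattened = flattened.T        # make the matrix tall

    q, r = np.linalg.qr(flattened)     # reduced QR
    q = q * np.sign(np.diag(r))        # sign fix for a uniform (Haar) Q

    if rows < cols:
        q = q.T                        # back to rows x cols

    return (gain * q).reshape(shape)   # scale, then restore the original shape


w = orthogonal_np((4, 4, 3, 3), gain=2.0, rng=np.random.default_rng(0))
m = w.reshape(4, -1)
print(np.allclose(m @ m.T, 4.0 * np.eye(4)))  # True: W W^T = gain**2 * I
```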

Test the orthogonal_ function. We use the weight tensor of a convolutional layer and the weight tensor of a linear layer as test cases, and check that each result is orthogonal after flattening.

def test_orthogonal() -> None:

Convolutional layer weights.

    w1 = torch.empty((4, 4, 3, 3))
    orthogonal_(w1)

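Test whether the result is orthogonal. After flattening, w1 is a wide 4 × 36 matrix W, so its rows should be orthonormal; transposing it lets us check that W Wᵀ equals the 4 × 4 identity.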

    w1 = w1.reshape(w1.shape[0], -1).T
    res = w1.T @ w1
    gt = torch.eye(w1.shape[1])
    assert torch.sum((res - gt) ** 2).item() < 1e-9

Linear layer weights.

    w2 = torch.empty((4, 4))
    orthogonal_(w2)

Test whether the result is orthogonal.

    res = w2.T @ w2
    gt = torch.eye(w2.shape[1])
    assert torch.sum((res - gt) ** 2).item() < 1e-9

