Appendix E: Auto Differentiation#

Mahmood Amintoosi, Spring 2024

Computer Science Dept, Ferdowsi University of Mashhad

In this notebook, we will explain how the auto-differentiation module of PyTorch works. This module is named Autograd.

We will first show how to ask PyTorch to compute the gradient for a specific variable and how to check the value of that gradient. Then we will use the backward function to perform the gradient computation. Finally, we will see how to detach a tensor from its computation history and how to tell PyTorch not to keep track of operations (useful at inference time!).

More advanced autograd functions are also explained, but we won’t go through them during the workshop.


import torch
torch.__version__
'1.12.1+cpu'

How does Autograd work?#

When you do operations on tensors, PyTorch can keep track of the computation graph in order to be able to backpropagate. To tell PyTorch to record the operations performed on a tensor, call its requires_grad_ method.

If at least one input to an operation requires gradients, its output will also require gradients. Conversely, the output does not require gradients only if none of the inputs do. Backward computation is never performed in subgraphs where no tensor requires gradients.
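
A minimal sketch of this propagation rule (the variable names are just for illustration):

a = torch.ones(2, requires_grad=True)   # tracked
b = torch.ones(2)                        # requires_grad is False by default

print((a + b).requires_grad)  # True: at least one input requires gradients
print((b * 2).requires_grad)  # False: no input requires gradients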

In-place operations can interfere with gradient computation. In particular, in-place modification of a leaf tensor that requires gradients is not allowed: x.zero_() raises an error if x is a leaf tensor with requires_grad=True.
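
A minimal sketch of that error (commented out, as elsewhere in this notebook; the exact message may vary across PyTorch versions):

x = torch.ones(2, requires_grad=True)
# x.zero_()  # RuntimeError: a leaf Variable that requires grad is being used in an in-place operation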

For a tensor x, the underlying data is stored in a tensor accessible via x.data. If you perform an operation on x.data, PyTorch does not add it to the computation graph.
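
For example (a small sketch):

x = torch.ones(2, requires_grad=True)
y = x.data * 2          # operates on the underlying data, outside the graph
print(y.requires_grad)  # False
print(y.grad_fn)        # None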

The requires_grad property#

Each tensor has a property requires_grad specifying whether the gradient should be computed during the backward pass.

The function requires_grad_(bool) (notice the trailing _ ) is used to change this property.

A = torch.randint(10, (1,2), dtype=torch.float)
print("A : ", A)

print("A.requires_grad :", A.requires_grad)

A.requires_grad_(True)
print("A.requires_grad :", A.requires_grad)

A.requires_grad_(False)
print("A.requires_grad :", A.requires_grad)
A :  tensor([[3., 1.]])
A.requires_grad : False
A.requires_grad : True
A.requires_grad : False

Backward function#

Here we will see a simple example of how to compute the gradient of a function automatically with PyTorch, and check that it corresponds to what we can compute manually.

Let’s look at the function \(f(x) = 5x^2+3\). Its derivative is \(f'(x) = 10x\), so at \(x=4\) we expect the value 40.

x = torch.Tensor([4])
x.requires_grad_()

f = 5 * x ** 2 + 3

f.backward()
print("∂f/∂x| x=4 :", x.grad.item())
∂f/∂x| x=4 : 40.0

Let’s look at the function \(f(x, y) = 5x^2+3\sin(y)\). Its partial derivatives are \(\partial f/\partial x = 10x\) and \(\partial f/\partial y = 3\cos(y)\), i.e. 40 and 3 at \((x, y) = (4, 0)\).

x = torch.Tensor([4])
x.requires_grad_()
y = torch.Tensor([0])
y.requires_grad_()

f = 5 * x**2 + 3*torch.sin(y)

f.backward()
print("∂f/∂x| x=4 :", x.grad.item())
print("∂f/∂y| y=0 :", y.grad.item())
∂f/∂x| x=4 : 40.0
∂f/∂y| y=0 : 3.0

Using autograd.grad#

from torch import autograd
x = torch.Tensor([4])
x.requires_grad_()
y = torch.Tensor([0])
y.requires_grad_()

f = 5 * x**2 + 3*torch.sin(y)

df_dx = autograd.grad(f, x, retain_graph=True)[0]  # retain the graph so that f can be differentiated again below
print("∂f/∂x| x=4 :", df_dx.item())

df_dy = autograd.grad(f, y)[0]
print("∂f/∂y| y=0 :", df_dy.item())
∂f/∂x| x=4 : 40.0
∂f/∂y| y=0 : 3.0
# autograd.grad returns a tuple with one gradient per input; [0] extracts the tensor.
# f is rebuilt before each call because the graph is freed after every autograd.grad call.
f = 5 * x**2 + 3*torch.sin(y)
print(type(autograd.grad(f, x)))
f = 5 * x**2 + 3*torch.sin(y)
print(autograd.grad(f, x))

f = 5 * x**2 + 3*torch.sin(y)
print(type(autograd.grad(f, x)[0]))
f = 5 * x**2 + 3*torch.sin(y)
print(autograd.grad(f, x)[0])
<class 'tuple'>
(tensor([40.]),)
<class 'torch.Tensor'>
tensor([40.])

Second Derivative#

To compute second derivatives, pass create_graph=True so that the gradient computation itself becomes part of the graph and can be differentiated again.

x = torch.Tensor([4])
x.requires_grad_()
y = torch.Tensor([1])
y.requires_grad_()

f = 5 * x**2 * y  # ∂f/∂x = 10xy, ∂²f/∂x² = 10y, ∂f/∂y = 5x²

df_dx = autograd.grad(f, x, create_graph=True)[0]
print("∂f/∂x| x=4 :", df_dx.item())
d2f_dx2 = autograd.grad(df_dx, x)[0]
print("∂2f/∂x2| x=4 :", d2f_dx2.item())

df_dy = autograd.grad(f, y, create_graph=True)[0]
print("∂f/∂y| y=1 :", df_dy.item())

# ∂f/∂y = 5x² does not depend on y, so y does not appear in the graph of df_dy;
# allow_unused=True makes autograd.grad return None instead of raising an error.
d2f_dy2 = autograd.grad(df_dy, y, allow_unused=True)[0]
# print("∂2f/∂y2| y=1 :", d2f_dy2.item())  # Error: d2f_dy2 is None
print("∂2f/∂y2| y=1 :", d2f_dy2)
∂f/∂x| x=4 : 40.0
∂2f/∂x2| x=4 : 10.0
∂f/∂y| y=1 : 80.0
∂2f/∂y2| y=1 : None

Let’s look at the function \(f(x, y) = \sin\big( \langle x , y \rangle \big)\)

which is equal to \(\sin(\sum_i x_iy_i)\)

X = torch.Tensor([1, 2, 3]).requires_grad_(True)
Y = torch.Tensor([5, 6, 7]).requires_grad_(True)

f = torch.sin(torch.dot(X,Y))
print("f =", f)
f = tensor(0.2964, grad_fn=<SinBackward0>)

We simply need to call the backward function on \(f\).

The backward function will automatically compute all the gradients of \(f\) w.r.t. the inputs using the chain rule!

# Gradient is populated by the backward function

f.backward()
print("\n-- Backward --\n")
print("X.grad :", X.grad)
print("Manual Derivative:", Y*torch.cos(torch.dot(X,Y)))
print("Y.grad :", Y.grad)
-- Backward --

X.grad : tensor([4.7754, 5.7304, 6.6855])
Manual Derivative: tensor([4.7754, 5.7304, 6.6855], grad_fn=<MulBackward0>)
Y.grad : tensor([0.9551, 1.9101, 2.8652])

And now the same gradients using autograd.grad, which returns them directly instead of populating .grad:

f = torch.sin(torch.dot(X,Y))

df_dx = autograd.grad(f, X, create_graph=True)[0]  # create_graph=True also retains the graph, so f can be differentiated w.r.t. Y below
print("∂f/∂x :", df_dx)

df_dy = autograd.grad(f, Y)[0]
print("∂f/∂y:", df_dy)
∂f/∂x : tensor([4.7754, 5.7304, 6.6855], grad_fn=<MulBackward0>)
∂f/∂y: tensor([0.9551, 1.9101, 2.8652])

Now let’s compute it manually!#

  • \(f\) can be written as a composite function \(f = h \circ g\)

    \(h(z) = \sin(z)\) with derivative \(\dfrac{d h}{d z}(z) = \cos(z)\)

    \(g(x, y) = \langle x , y \rangle\)

We know that: \(\dfrac{\partial }{\partial x}(x^Ty) = \dfrac{\partial }{\partial x}(y^Tx) = y\)
See Wikipedia: Matrix Calculus

  • Using the chain rule, we can easily get the derivative of \(f(x, y) = \sin\big( \langle x , y \rangle \big)\) w.r.t. \(x\) and \(y\):

\(\dfrac{\partial f }{\partial x} (x,y) = \cos\big( \langle x , y \rangle \big) \cdot y \)

and

\(\dfrac{\partial f }{\partial y} (x,y) = \cos\big( \langle x , y \rangle \big) \cdot x \)

df_dx_man = torch.cos(torch.dot(X,Y)) * Y
print("df / dx = ", df_dx_man)
df / dx =  tensor([4.7754, 5.7304, 6.6855], grad_fn=<MulBackward0>)
df_dy_man = torch.cos(torch.dot(X,Y)) * X
print("df / dy = ", df_dy_man)
df / dy =  tensor([0.9551, 1.9101, 2.8652], grad_fn=<MulBackward0>)
print(df_dx)
print(df_dx_man)
tensor([4.7754, 5.7304, 6.6855], grad_fn=<MulBackward0>)
tensor([4.7754, 5.7304, 6.6855], grad_fn=<MulBackward0>)

Success!

Leaf Variable#

A variable that was created by the user, and is therefore not the result of any operation, is called a leaf variable.
All variables whose requires_grad property is False are also considered leaf variables.

A = torch.Tensor([[1, 2], [3, 4]]).requires_grad_()
B = torch.Tensor([[1, 2], [3, 4]]).requires_grad_() + 2  # B is the result of an operation (+)
C = 5 * A  # C is the result of an operation (*)
D = torch.Tensor([[1, 2], [3, 4]])
print("A.is_leaf :", A.is_leaf)
print("B.is_leaf :", B.is_leaf)
print("C.is_leaf :", C.is_leaf)
print("D.is_leaf :", D.is_leaf)
A.is_leaf : True
B.is_leaf : False
C.is_leaf : False
D.is_leaf : True

Detach function#

A variable can have a long computation history, but you may want to consider it as a new leaf variable without history.

For that, you can use the detach function, which detaches the tensor from its history.

A = torch.Tensor([1, 2]).requires_grad_()
B = A.mean()

print("B : ", B)
print("B.grad_fn :", B.grad_fn)
print("B.is_leaf :", B.is_leaf)
B.backward()
B :  tensor(1.5000, grad_fn=<MeanBackward0>)
B.grad_fn : <MeanBackward0 object at 0x000002150CEA32C8>
B.is_leaf : False
B.detach_()
print("\n -- B.detach_() -- \n")

print("B : ", B)
print("B.grad_fn :", B.grad_fn)
print("B.is_leaf :", B.is_leaf)
# This won't work since B has no history.
# B.backward()
 -- B.detach_() -- 

B :  tensor(1.5000)
B.grad_fn : None
B.is_leaf : True
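
Note that detach_() modifies the tensor in place. The non-in-place variant detach() instead returns a new tensor that shares the same data but has no history; a minimal sketch:

A = torch.Tensor([1, 2]).requires_grad_()
B = A.mean()

C = B.detach()                               # new tensor, cut from the graph
print("C.is_leaf :", C.is_leaf)              # True
print("C.requires_grad :", C.requires_grad)  # False
print("B.grad_fn :", B.grad_fn)              # B keeps its own history
B.backward()                                 # still possible, since B itself was not modified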

No_grad function#

At inference time, you don’t want PyTorch to build a computation graph. This can be achieved by wrapping your inference code in the with torch.no_grad() context manager.

x = torch.randn(3, requires_grad=True)
print("x.requires_grad : ", x.requires_grad)

y = (x ** 2)
print("y.requires_grad : ", y.requires_grad)

with torch.no_grad():
    y = (x ** 2)
    print("y.requires_grad : ", y.requires_grad)
x.requires_grad :  True
y.requires_grad :  True
y.requires_grad :  False

Note: Autograd in previous PyTorch versions#

In older versions of PyTorch, one had to wrap a Tensor in an Autograd object called Variable.

Variable was a thin wrapper around a Tensor object that also held the gradient w.r.t. it and a reference to the function that created it. This reference allowed retracing the whole chain of operations that created the data.

Now, Tensors are Variables by default and we don’t need to worry about this anymore, but you may still encounter Variable in some “old” code.

# from torch.autograd import Variable

# x = Variable(torch.randn(5, 5))
# x

Advanced concepts of Autograd#

The following concepts are more advanced, and you may want to skip them for now.
We won’t go through them, but they are here for you to come back to later, when you feel more comfortable with PyTorch.
You can also check the PyTorch documentation.

Retain Grad#

When doing the backward pass, Autograd computes the gradient of the output with respect to every intermediate variable. However, by default, only the gradients of variables that were created by the user (leaves) and have requires_grad set to True are saved.

Indeed, most of the time when training a model you only need the gradient of the loss w.r.t. your model parameters.

x = torch.Tensor([4])
x.requires_grad_()

f = 5 * x ** 2 + 3

f.backward()
print("∂f/∂x| x=4 :", x.grad.item())
∂f/∂x| x=4 : 40.0
A = torch.Tensor([[1, 2], [3, 4]])
A.requires_grad_()

B = 5 * (A + 3)
C = B.mean()

# print("A.grad :", A.grad)
# print("B.grad :", B.grad)
C.backward()
print("\n-- Backward --\n")
print("A.grad :", A.grad)
print("B.grad :", B.grad)
-- Backward --

A.grad : tensor([[1.2500, 1.2500],
        [1.2500, 1.2500]])
B.grad : None
C:\Programs\Anaconda3\envs\ptch\lib\site-packages\torch\_tensor.py:1013: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at  aten\src\ATen/core/TensorBody.h:417.)
  return self._grad
A = torch.Tensor([[1, 2], [3, 4]])
A.requires_grad_()

B = 5 * (A + 3)
B.retain_grad()  # <----- This line let us have access to gradient wrt. B after the backward pass
C = B.mean()


print("A.grad :", A.grad)
print("B.grad :", B.grad)
C.backward()
print("\n-- Backward --\n")
print("A.grad :", A.grad)
print("B.grad :", B.grad)
A.grad : None
B.grad : None

-- Backward --

A.grad : tensor([[1.2500, 1.2500],
        [1.2500, 1.2500]])
B.grad : tensor([[0.2500, 0.2500],
        [0.2500, 0.2500]])

Gradient accumulation#

You can call backward a first time and get a gradient for A, then do some other computation using A and call backward again.
The gradients will get accumulated in A.grad.

A = torch.Tensor([[1, 2], [3, 4]]).requires_grad_()

print("A.grad :", A.grad)

B = 5 * (A + 3)
C = B.mean()
C.backward()

print("\n-- Backward --\n")
print("A.grad :", A.grad)

B = 5 * (A + 3)
C = B.mean()
C.backward()

print("\n-- Backward --\n")
print("A.grad :", A.grad)
A.grad : None

-- Backward --

A.grad : tensor([[1.2500, 1.2500],
        [1.2500, 1.2500]])

-- Backward --

A.grad : tensor([[2.5000, 2.5000],
        [2.5000, 2.5000]])
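
Because gradients accumulate, training loops usually reset them between iterations, either with optimizer.zero_grad() for a model's parameters or by zeroing the .grad tensor directly. A minimal sketch continuing the example above:

A.grad.zero_()              # reset the accumulated gradient in place
print("A.grad :", A.grad)

B = 5 * (A + 3)
C = B.mean()
C.backward()
print("A.grad :", A.grad)   # 1.25 everywhere again, instead of accumulating to 3.75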

Under the hood…#

This part gives a glimpse of how things work under the hood; we don’t need to do this kind of inspection in practice.
Here, we have a look at the computation graph that Autograd builds on the fly.

A = torch.Tensor([[1, 2], [3, 4]])
A.requires_grad_()

B = 5 * (A + A)
C = B.mean()

Each tensor that results from an operation has a gradient function, grad_fn; for a leaf tensor such as A it is None.

print(A.grad_fn)
print(B.grad_fn)
print(C.grad_fn)
None
<MulBackward0 object at 0x000001565DBD7348>
<MeanBackward0 object at 0x000001565DBD7388>

We can also “walk” the computation graph by following the next_functions attribute.

grad_fn = C.grad_fn
print(grad_fn)

grad_fn = grad_fn.next_functions
print(grad_fn)

grad_fn = grad_fn[0][0].next_functions
print(grad_fn)

grad_fn = grad_fn[0][0].next_functions
print(grad_fn)
<MeanBackward0 object at 0x000001565DBC4788>
((<MulBackward0 object at 0x000001565DBC4948>, 0),)
((<AddBackward0 object at 0x000001565DBC4788>, 0), (None, 0))
((<AccumulateGrad object at 0x000001565DBC4948>, 0), (<AccumulateGrad object at 0x000001565DBC4948>, 0))