suyc's blog

哎,什么时候才能把英语学好啊🙇‍~

PyTorch backward 参数解读

官网Doc链接:link

1
backward(gradient=None, retain_graph=None, create_graph=False)

看一下第一个参数的意思是什么。其实不管Tensor中是几个变量,都可以传入参数,但是一般情况下,我们的out都是只有一个标量的tensor,可以不传入参数。

举个例子,通常我们在求导的时候都是对一个变量求导,例如

所以当我们对求导的时候很自然的可以使用链式法则

但是当我们开始求导的位置是时,如torch.Tensor([H,G]),程序就并不知道这个要从那里开始了,这两个参数有什么关系,所以我们就得给这两个变量假想一个关系,假如存在一个关系

这个关系是什么不重要,重要的是我们在backward中传入的参数就是一个二元组,分别是

这两个参数分别是

的系数。程序最后就可以把这两个关系用上,继续求导。

所以在官网的教程中写道:

out.backward() is equivalent to out.backward(torch.tensor(1.))

也就不难理解了。也就是说可能存在一个关系

在开始求导的时候,这个系数的值