Implementing automatic differentiation in Paddle (from Mu Li's deep learning course)

As a running example, we differentiate y = 2xᵀx with respect to the vector x. Create x with stop_gradient=False so that Paddle records operations on it, and confirm that gradient tracking is enabled:
import paddle
x = paddle.to_tensor([i for i in range(4)], dtype='float32', stop_gradient=False)
x
Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=False, [0., 1., 2., 3.])
paddle.is_grad_enabled()
True
y = 2 * paddle.dot(x, x)
y
Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False, [28.])
paddle.autograd.backward(y, retain_graph=True)
x.grad
Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=False, [0. , 4. , 8. , 12.])
x.grad == 4 * x
Tensor(shape=[4], dtype=bool, place=Place(gpu:0), stop_gradient=True, [True, True, True, True])
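Here y = 2(x·x) = 2Σᵢxᵢ², so the analytic gradient is ∂y/∂xᵢ = 4xᵢ, which is exactly what the comparison above confirms. As a side note, paddle.grad offers a functional way to obtain the same gradient without writing into x.grad; a minimal sketch, assuming the dygraph-mode paddle.grad API:

import paddle

x = paddle.to_tensor([0., 1., 2., 3.], stop_gradient=False)
y = 2 * paddle.dot(x, x)     # y = 2 * sum(x_i ** 2)
grad = paddle.grad(y, x)[0]  # returns a list of gradients; here 4 * x
print(grad)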
x.clear_grad()
y = x.sum()
paddle.autograd.backward(y, retain_graph=True)
x.grad
Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=False, [1., 1., 1., 1.])
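The x.clear_grad() call above matters: by default Paddle accumulates gradients into x.grad across backward calls instead of overwriting them. A minimal sketch of that accumulation, assuming dygraph mode:

import paddle

x = paddle.to_tensor([0., 1., 2., 3.], stop_gradient=False)
for _ in range(2):
    y = x.sum()
    paddle.autograd.backward(y)  # no clear_grad between the two calls
print(x.grad)  # gradients add up; expected [2., 2., 2., 2.]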
x.clear_grad()
y = x * x
paddle.autograd.backward(y.sum(), retain_graph=True)
x.grad
Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=False, [0., 2., 4., 6.])
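Backward is run on y.sum() because y = x * x is a vector; summing turns it into a scalar loss, and ∂(Σᵢxᵢ²)/∂xᵢ = 2xᵢ gives the gradient 2x. Paddle can also take the non-scalar y directly, filling in an all-ones initial gradient (this is also what makes the backward on the matrix d below work); a sketch under that assumption:

import paddle

x = paddle.to_tensor([0., 1., 2., 3.], stop_gradient=False)
y = x * x
paddle.autograd.backward(y)  # implicit all-ones gradient, equivalent to y.sum() here
print(x.grad)  # expected [0., 2., 4., 6.]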
Moving computation out of the recorded graph

Sometimes we want part of a computation to be treated as a constant during backpropagation. detach() returns a tensor with the same values but cut off from the graph: below, u holds the value of x * x, yet in z = u * x it acts as a constant, so the gradient of z.sum() with respect to x is u rather than 3x².
x.clear_grad()
y = x * x
u = y.detach()
z = u * x
paddle.autograd.backward(z.sum(), retain_graph=True)
x.grad == u
Tensor(shape=[4], dtype=bool, place=Place(gpu:0), stop_gradient=True, [True, True, True, True])
x.clear_grad()
paddle.autograd.backward(y.sum(), retain_graph=True)
x.grad == 2 * x
Tensor(shape=[4], dtype=bool, place=Place(gpu:0), stop_gradient=True, [True, True, True, True])
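Note that detaching y into u does not destroy the graph that produced y, which is why backpropagating y.sum() afterwards still yields 2x. Another way to keep computation out of the graph is the paddle.no_grad() context manager; a minimal sketch, assuming the standard dygraph API:

import paddle

x = paddle.to_tensor([0., 1., 2., 3.], stop_gradient=False)
with paddle.no_grad():
    u = x * x  # not recorded; u comes out with stop_gradient=True
z = u * x
paddle.autograd.backward(z.sum())
print(x.grad)  # expected to equal u: [0., 1., 4., 9.]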
Automatic differentiation through complex control flow

A benefit of automatic differentiation is that it still works when building the graph requires Python control flow (conditionals, loops, function calls): autograd simply records the operations that actually ran. In the function below, the number of loop iterations and the branch taken both depend on the value of the input a.
def f(a):
    b = a * 2
    while b.norm() < 1000:  # keep doubling until the norm reaches 1000
        b = b * 2
    if b.sum() > 0:
        c = b
    else:
        c = 100 * b
    return c
a = paddle.randn(shape=[10, 10], dtype='float32')
a.stop_gradient = False
d = f(a)
paddle.autograd.backward(d, retain_graph=True)
a.grad == d / a
Tensor(shape=[10, 10], dtype=bool, place=Place(gpu:0), stop_gradient=True,
[[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
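Whatever path execution takes, f only ever multiplies its input by scalars, so f(a) = k * a for some scalar k that depends on a; the gradient of each output element with respect to its input element is therefore exactly k = d / a, which the all-True comparison above confirms. A minimal sketch of the same check with a scalar input, reusing the f defined above (assumption: dygraph mode):

import paddle

a = paddle.randn([1], dtype='float32')
a.stop_gradient = False
d = f(a)  # f as defined above
paddle.autograd.backward(d)
print(bool((a.grad == d / a).all()))  # expected True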