Какая польза от torch.no_grad в pytorch?


21

Я новичок в Pytorch и начал с этим кодом GitHub . Я не понимаю комментарий в строке 60-61 в коде "because weights have requires_grad=True, but we don't need to track this in autograd". Я понял, что мы упоминаем requires_grad=Trueпеременные, которые нам нужны для вычисления градиентов для использования автограда, но что это значит "tracked by autograd"?

Ответы:


24

Оболочка "with torch.no_grad ()" временно установила для флага флаг require_grad значение false. Пример из официального руководства по PyTorch ( https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradients ):

x = torch.randn(3, requires_grad=True)
print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)

Вне:

True
True
False

Я рекомендую вам прочитать все учебники с сайта выше.

В вашем примере: я думаю, что автор не хочет, чтобы PyTorch вычислял градиенты новых определенных переменных w1 и w2, поскольку он просто хочет обновить их значения.


6
with torch.no_grad()

сделает все операции в блоке без градиентов.

В pytorch вы не можете изменить размещение w1 и w2, которые являются двумя переменными с require_grad = True. Я думаю, что избегание изменения размещения w1 и w2 связано с тем, что это приведет к ошибке в расчете обратного распространения. Так как изменение размещения полностью изменит w1 и w2.

Однако, если вы используете это no_grad(), вы можете контролировать новый w1, и новый w2 не имеет градиентов, так как они генерируются операциями, что означает, что вы изменяете только значения w1 и w2, а не часть градиента, они все еще имеют ранее определенную переменную информацию о градиенте и обратное распространение может продолжаться.

Используя наш сайт, вы подтверждаете, что прочитали и поняли нашу Политику в отношении файлов cookie и Политику конфиденциальности.
Licensed under cc by-sa 3.0 with attribution required.