- torch
- torch.Tensor
- Tensor Attributes
- torch.sparse
- torch.cuda
- torch.Storage
- torch.nn
- torch.nn.functional
- torch.autograd
- torch.optim
- torch.nn.init
- torch.distributions
- torch.multiprocessing
- torch.distributed
- torch.utils.bottleneck
- torch.utils.checkpoint
- torch.utils.cpp_extension
- torch.utils.data
- torch.utils.ffi
- torch.utils.model_zoo
- torch.onnx
- torch.legacy
包参考
序列化语义
最佳实践
保存模型的推荐方法
序列化和恢复模型有两种主要方法。
第一个(推荐)只保存和加载模型参数:
torch.save(the_model.state_dict(), PATH)
然后:
the_model = TheModelClass(*args, **kwargs)
the_mdel.load_state_dict(torch.load(PATH))
第二个方法是保存并加载整个模型:
torch.save(the_model, PATH)
然后:
the_model = torch.load(PATH)
然而,在这种情况下,序列化的数据会与特定的类结构和准确的目录结构相绑定,所以在其他项目中使用或经大量重构之后,这些结构可能会以各种方式被破坏。