这篇教程Pytorch快速下载预训练模型并修改保存路径写得很实用,希望能帮到您。 【Pytorch】快速下载预训练模型并修改保存路径
首次用Pytorch加载预训练模型,需要在线下载,但是下载速度比较慢。下载后会保存在本地缓存里。如果能直接加载本地下载好的模型就会快了,主要是个修改路径的问题。
所以要提升速度一般有两种方法: 1.修改torch源码,一次性改变下载url 2.将离线模型权重存到缓存文件夹里 参考:pytorch预训练模型的下载地址以及解决下载速度慢的方法 - you-wh - 博客园 参考:【Pytorch】加载torchvision中预训练好的模型并修改默认下载路径_ProLover98的博客-CSDN博客 参考:pytorch 加载(.pth)格式的模型_人工智能_u014264373的博客-CSDN博客(没有修改存储路径) 但是用云服务器时候这两种方法都有点问题。如果预训练模型的下载路径和存储路径能随用随改就最好了。
所以以vgg16为例,本文采用的方法是:
import torch from torchvision import models pthfile = 'file:///mnt/model/vgg16-397923af.pth' #在下载好的pth文件路径前加file:///得到url pthsavefile = '/mnt/vgg16-397923af.pth' #这是模型保存的路径 model = models.vgg.vgg16(pretrained=False, progress=True) #定义一个不需要预训练的模型。如果pretrained=True就会自动下载了 state_dict = torch.utils.model_zoo.load_url(pthfile, model_dir=pthsavefile, map_location=None, progress=True, check_hash=False) # 从pthfile下载到pthsavefile。默认model_dir为none model.load_state_dict(state_dict) # 读取下载好的模型 # 设置好参数就可以train了 model = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,num_epochs=25)
模型可以任意换成别的,比如
models.vgg.vgg16 models.resnet.resnet18 models.resnet.resnext50_32x4d pytorch实现从本地加载 .pth 格式模型 CelebA数据集详细介绍及其属性提取源代码 |