这篇教程深入理解Pytorch微调torchvision模型写得很实用,希望能帮到您。
一、简介在本小节,深入探讨如何对torchvision进行微调和特征提取。所有模型都已经预先在1000类的magenet数据集上训练完成。 本节将深入介绍如何使用几个现代的CNN架构,并将直观展示如何微调任意的PyTorch模型。 本节将执行两种类型的迁移学习: - 微调:从预训练模型开始,更新我们新任务的所有模型参数,实质上是重新训练整个模型。
- 特征提取:从预训练模型开始,仅更新从中导出预测的最终图层权重。它被称为特征提取,因为我们使用预训练的CNN作为固定 的特征提取器,并且仅改变输出层。
通常这两种迁移学习方法都会遵循一下步骤: - 初始化预训练模型
- 重组最后一层,使其具有与新数据集类别数相同的输出数
- 为优化算法定义想要的训练期间更新的参数
- 运行训练步骤
二、导入相关包from __future__ import print_functionfrom __future__ import divisionimport torchimport torch.nn as nnimport torch.optim as optimimport numpy as npimport torchvision from torchvision import datasets,models,transformsimport matplotlib.pyplot as pltimport timeimport osimport copyprint("Pytorch version:",torch.__version__)print("torchvision version:",torchvision.__version__) 运行结果 
三、数据输入数据集 Python 中 Shutil 模块详情 详解Python调试神器之PySnooper |