这篇教程pytorch读取csv数据集写得很实用,希望能帮到您。 pytorch读取csv数据集
文章标签: 深度学习 pytorch python 版权 华为云开发者联盟 该内容已被华为云开发者联盟社区收录 加入社区
在pytorch中已经包含了许多的内置数据集,我们可以很简单的调用其内置的,但是在现实的过程之中我们往往会使用自己的数据集。这就使得读取自己的数据集并进行训练会有很大的问题。
因此对于csv格式的数据集合,以下图为例
每一份csv文件为一个样本,对应的标签数据也是使用csv格式进行存储。
对于这种情况,我们可以使用重写dataset类来解决这个问题,利用迭代的方式依次读取对应的data和label。
代码如下:
class myDataSet(Dataset): def __init__(self, data_dir, label_dir, transform=None): """ :param data_dir: 数据文件路径 :param label_dir: 标签文件路径 :param transform: transform操作 """ self.transform = transform # 读文件夹下每个数据文件名称 #os.listdir读取文件夹内的文件名称 self.file_name = os.listdir(data_dir) # 读标签文件夹下的数据名称 self.label_name = os.listdir(label_dir) self.data_path = [] self.label_path = [] #让每一个文件的路径拼接起来 for index in range(len(self.file_name)): self.data_path.append(os.path.join(data_dir,self.file_name[index])) self.label_path.append(os.path.join(label_dir, self.label_name[index])) def __len__(self): # 返回数据集长度 return len(self.file_name) def __getitem__(self, index): # 获取每一个数据 #读取数据 data = pd.read_csv(self.data_path[index],header=None) #读取标签 label = pd.read_csv(self.label_path[index],header=None) if self.transform : data = self.transform(data) label = self.transform(label) #转成张量 data = torch.tensor(data.values) label = torch.tensor(label.values) return data, label # 返回数据和标签
重构dataset类之后,读取数据并使用dataloader进行数据的加载
data_dir = r"./data/Circle/BV/" label_dir = r"./data/Circle/DDL/" #读取数据集 train_dataset = myDataSet( data_dir = data_dir, label_dir = label_dir, ) #加载数据集 train_iter = DataLoader(train_dataset)
成功加载数据集之后就可以构建自己的网络来进行训练。
ps:学生新手,如果有不足之处还希望大家多多批评指正。 ———————————————— 版权声明:本文为CSDN博主「VanGogh777」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。 原文链接:https://blog.csdn.net/qq_42653159/article/details/124511467 返回列表 PyTorch- 多模态融合 |