您当前的位置:首页 > IT编程 > pytorch
| C语言 | Java | VB | VC | python | Android | TensorFlow | C++ | oracle | 学术与代码 | cnn卷积神经网络 | gnn | 图像修复 | Keras | 数据集 | Neo4j | 自然语言处理 | 深度学习 | 医学CAD | 医学影像 | 超参数 | pointnet | pytorch | 异常检测 | Transformers | 情感分类 |

自学教程:pytorch读取csv数据集

51自学网 2023-11-04 22:08:13
  pytorch
这篇教程pytorch读取csv数据集写得很实用,希望能帮到您。
pytorch读取csv数据集

文章标签: 深度学习 pytorch python
版权
华为云开发者联盟 该内容已被华为云开发者联盟社区收录
加入社区

在pytorch中已经包含了许多的内置数据集,我们可以很简单的调用其内置的,但是在现实的过程之中我们往往会使用自己的数据集。这就使得读取自己的数据集并进行训练会有很大的问题。

因此对于csv格式的数据集合,以下图为例

 

每一份csv文件为一个样本,对应的标签数据也是使用csv格式进行存储。

对于这种情况,我们可以使用重写dataset类来解决这个问题,利用迭代的方式依次读取对应的data和label。

代码如下:

    class myDataSet(Dataset):
        def __init__(self, data_dir, label_dir, transform=None):
            """
            :param data_dir: 数据文件路径
            :param label_dir: 标签文件路径
            :param transform: transform操作
            """
            self.transform = transform
            # 读文件夹下每个数据文件名称
            #os.listdir读取文件夹内的文件名称
            self.file_name = os.listdir(data_dir)
            # 读标签文件夹下的数据名称
            self.label_name = os.listdir(label_dir)
     
            self.data_path = []
            self.label_path = []
     
            #让每一个文件的路径拼接起来
            for index in range(len(self.file_name)):
                self.data_path.append(os.path.join(data_dir,self.file_name[index]))
                self.label_path.append(os.path.join(label_dir, self.label_name[index]))
     
        def __len__(self):
            # 返回数据集长度
            return len(self.file_name)
     
        def __getitem__(self, index):
            # 获取每一个数据
     
            #读取数据
            data = pd.read_csv(self.data_path[index],header=None)
            #读取标签
            label = pd.read_csv(self.label_path[index],header=None)
     
            if self.transform :
                data = self.transform(data)
                label = self.transform(label)
     
            #转成张量
            data = torch.tensor(data.values)
            label = torch.tensor(label.values)
     
            return data, label  # 返回数据和标签

重构dataset类之后,读取数据并使用dataloader进行数据的加载

        data_dir = r"./data/Circle/BV/"
        label_dir = r"./data/Circle/DDL/"
     
        #读取数据集
        train_dataset = myDataSet(
            data_dir = data_dir,
            label_dir = label_dir,
        )
        #加载数据集
        train_iter = DataLoader(train_dataset)

成功加载数据集之后就可以构建自己的网络来进行训练。

ps:学生新手,如果有不足之处还希望大家多多批评指正。
————————————————
版权声明:本文为CSDN博主「VanGogh777」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_42653159/article/details/124511467
返回列表
PyTorch- 多模态融合
51自学网自学EXCEL、自学PS、自学CAD、自学C语言、自学css3实例,是一个通过网络自主学习工作技能的自学平台,网友喜欢的软件自学网站。
京ICP备13026421号-1