您当前的位置:首页 > IT编程 > python
| C语言 | Java | VB | VC | python | Android | TensorFlow | C++ | oracle | 学术与代码 | cnn卷积神经网络 | gnn | 图像修复 | Keras | 数据集 | Neo4j | 自然语言处理 | 深度学习 | 医学CAD | 医学影像 | 超参数 | pointnet | pytorch | 异常检测 |

自学教程:自定义数据集的训练测试txt文件生成

51自学网 2023-09-12 13:05:06
  python
这篇教程自定义数据集的训练测试txt文件生成写得很实用,希望能帮到您。

自定义数据集的训练测试txt文件生成

在自己数据集上根据指定比例,生成测试集和训练集,并写入txt文件

import os
import numpy as np

abs_str = ''            #绝对路径 + 文件名
dirname = ''            #源文件所在目录

import numpy as np


def get_file(dir):
    file_list = []
    label_list = []
    for (index,item) in enumerate(os.listdir(dir)):
        imagDir = os.path.join(os.path.abspath(dir),item)
        if(os.path.isdir(imagDir)):
            for image in os.listdir(imagDir):
                if os.path.isfile(os.path.join(imagDir,image)):
                    file_list.append(os.path.join(item,image))
                    label_list.append(index)
    return file_list,label_list

if __name__=="__main__":
    # file_list,label_list = get_file(dirname)
    # file_handle = open('train_test_split.txt',mode='w')
    # for i,j in enumerate(label_list):
    #     file_handle.write('{} '.format(i + 1))
    #     file_handle.write('{} '.format(j + 1)) # image_class_labels.txt
    #     # file_handle.write(j) # images.txt
    #     file_handle.write('\n')
    # file_handle.close()
    # ------------------------split-test-spilt-V1	
    nums = np.ones(5378, dtype=int)
    test_size = int(0.8 * len(nums))
    nums[:test_size] = 0
    np.random.shuffle(nums)
    file_handle = open('train_test_split.txt', mode='w')
    for i,j in enumerate(nums):
        file_handle.write('{} '.format(i + 1))
        # file_handle.write('{} '.format(j + 1)) # image_class_labels.txt
        file_handle.write('{} '.format(j)) # train_test_split.txt
        # file_handle.write(j) # images.txt
        file_handle.write('\n')
    file_handle.close()

发现一点问题,不能简单的根据一个随机数进行划分,存在一种可能是在某一个类中没有取到训练或者测试数据,有问题,因此还是需要进行遍历每一个文件夹,有了如下的更新:

import os
import numpy as np
import random
abs_str = ''            #绝对路径 + 文件名
dirname = ''            #源文件所在目录


def get_file(dir):
    file_list = []
    label_list = []
    for (index,item) in enumerate(os.listdir(dir)):
        imagDir = os.path.join(os.path.abspath(dir),item)
        if(os.path.isdir(imagDir)):
            for image in os.listdir(imagDir):
                if os.path.isfile(os.path.join(imagDir,image)):
                    file_list.append(os.path.join(item,image))
                    label_list.append(index)
    return file_list,label_list

def get_train_test(dir,split_rate):
    train_test_list = []
    for (index,item) in enumerate(os.listdir(dir)):
        imagDir = os.path.join(os.path.abspath(dir),item)
        if(os.path.isdir(imagDir)):
            print('imagDir', imagDir)
            images = os.listdir(imagDir)
            num = len(images)
            eval_index = random.sample(images, k=int(num * split_rate))
            for index, image in enumerate(images):
                if image in eval_index:
                    # 将分配至验证集中的文件复制到相应目录
                    train_test_list.append(0)
                else:
                    # 将分配至训练集中的文件复制到相应目录
                    train_test_list.append(1)
            print()
    return train_test_list
if __name__=="__main__":
    # file_list,label_list = get_file(dirname)
    # file_handle = open('train_test_split.txt',mode='w')
    # for i,j in enumerate(label_list):
    #     file_handle.write('{} '.format(i + 1))
    #     file_handle.write('{} '.format(j + 1)) # image_class_labels.txt
    #     # file_handle.write(j) # images.txt
    #     file_handle.write('\n')
    # file_handle.close()
    # ------------------------split-test-spilt-V1
    # nums = np.zeros(5378, dtype=int)
    # test_size = int(0.8 * len(nums))
    # nums[:test_size] = 1
    # np.random.shuffle(nums)
    # file_handle = open('train_test_split.txt', mode='w')
    # for i,j in enumerate(nums):
    #     file_handle.write('{} '.format(i + 1))
    #     # file_handle.write('{} '.format(j + 1)) # image_class_labels.txt
    #     file_handle.write('{} '.format(j)) # train_test_split.txt
    #     # file_handle.write(j) # images.txt
    #     file_handle.write('\n')
    # file_handle.close()
    # ------------------------split-test-spilt-V2
    nums = get_train_test(dirname,0.2)
    file_handle = open('train_test_split.txt', mode='w')
    for i,j in enumerate(nums):
        file_handle.write('{} '.format(i + 1))
        # file_handle.write('{} '.format(j + 1)) # image_class_labels.txt
        file_handle.write('{} '.format(j)) # train_test_split.txt
        # file_handle.write(j) # images.txt
        file_handle.write('\n')
    file_handle.close()
    n = str(nums).count('1')
    m = str(nums).count('0')
    print('nums',nums)
    print('train',n)
    print('value',m)
 

返回列表
如何用python生成带图片名称和标签的.txt文件(python代码)
51自学网自学EXCEL、自学PS、自学CAD、自学C语言、自学css3实例,是一个通过网络自主学习工作技能的自学平台,网友喜欢的软件自学网站。
京ICP备13026421号-1