在自己的数据集上,按指定比例划分训练集和测试集,并将划分结果写入 txt 文件:
import os
import numpy as np
abs_str = '' # absolute path + file name
dirname = '' # directory that contains the source files
import numpy as np
def get_file(dir):
    """Collect image paths and class labels from a class-per-folder dataset.

    Every sub-directory directly under ``dir`` is treated as one class and
    every regular file inside it as one sample.  Nested directories inside
    a class folder are ignored.

    Args:
        dir: Root directory of the dataset.

    Returns:
        A ``(file_list, label_list)`` tuple of two equal-length lists.
        ``file_list`` holds ``<class_dir>/<file>`` paths relative to ``dir``.
        ``label_list`` holds, for each file, the position of its class
        folder in ``os.listdir(dir)``; non-directory entries of ``dir``
        still consume a position, and the listing order is whatever the
        operating system returns.
    """
    entries = os.listdir(dir)
    root = os.path.abspath(dir)
    paths, labels = [], []
    for label, class_name in enumerate(entries):
        class_dir = os.path.join(root, class_name)
        if not os.path.isdir(class_dir):
            continue
        members = [
            os.path.join(class_name, name)
            for name in os.listdir(class_dir)
            if os.path.isfile(os.path.join(class_dir, name))
        ]
        paths.extend(members)
        labels.extend([label] * len(members))
    return paths, labels
if __name__=="__main__":
    # Alternative output (kept for reference): write one "<image_id> <label>"
    # line per image to produce image_class_labels.txt.
    # file_list,label_list = get_file(dirname)
    # with open('train_test_split.txt', mode='w') as file_handle:
    #     for i,j in enumerate(label_list):
    #         file_handle.write('{} {} \n'.format(i + 1, j + 1))  # image_class_labels.txt
    # ------------------------split-test-spilt-V1
    # Build one 0/1 flag per image and shuffle the flags globally.
    # 5378 is hard-coded; presumably the number of images in the author's
    # dataset -- TODO confirm before reuse.
    nums = np.ones(5378, dtype=int)
    # NOTE(review): this sets 80% of the flags to 0.  If 1 is meant to mark
    # training samples (as in the later per-class version of this script),
    # the ratio here is inverted -- confirm the intended convention.
    test_size = int(0.8 * len(nums))
    nums[:test_size] = 0
    np.random.shuffle(nums)
    # The context manager closes the file even if a write fails.
    with open('train_test_split.txt', mode='w') as file_handle:
        for i, j in enumerate(nums):
            # One line per image: "<1-based image id> <flag> " (the trailing
            # space before the newline matches the original output format).
            file_handle.write('{} {} \n'.format(i + 1, j))  # train_test_split.txt
后来发现一个问题:不能简单地对全部样本整体做一次随机划分,否则某个类别可能完全没有被分到训练集或测试集。因此仍然需要逐个遍历类别文件夹,在每个类别内部按比例划分。更新后的代码如下:
import os
import numpy as np
import random
abs_str = '' # absolute path + file name
dirname = '' # directory that contains the source files
def get_file(dir):
    """Collect image paths and class labels from a class-per-folder dataset.

    Every sub-directory directly under ``dir`` is treated as one class and
    every regular file inside it as one sample.  Nested directories inside
    a class folder are ignored.

    Args:
        dir: Root directory of the dataset.

    Returns:
        A ``(file_list, label_list)`` tuple of two equal-length lists.
        ``file_list`` holds ``<class_dir>/<file>`` paths relative to ``dir``.
        ``label_list`` holds, for each file, the position of its class
        folder in ``os.listdir(dir)``; non-directory entries of ``dir``
        still consume a position, and the listing order is whatever the
        operating system returns.
    """
    entries = os.listdir(dir)
    root = os.path.abspath(dir)
    paths, labels = [], []
    for label, class_name in enumerate(entries):
        class_dir = os.path.join(root, class_name)
        if not os.path.isdir(class_dir):
            continue
        members = [
            os.path.join(class_name, name)
            for name in os.listdir(class_dir)
            if os.path.isfile(os.path.join(class_dir, name))
        ]
        paths.extend(members)
        labels.extend([label] * len(members))
    return paths, labels
def get_train_test(dir,split_rate):
    """Build a per-class train/eval flag list for a class-per-folder dataset.

    Each sub-directory directly under ``dir`` is treated as one class.
    Within every class, ``int(n * split_rate)`` of its ``n`` files are drawn
    at random for evaluation, so the split is made per class rather than
    over the dataset as a whole.  Note that the count is truncated, so a
    class with fewer than ``1 / split_rate`` files gets no evaluation sample.

    Only regular files are counted, the same filter ``get_file`` applies, so
    the flags line up one-to-one with the files that function lists.

    Args:
        dir: Root directory of the dataset.
        split_rate: Fraction of each class assigned to the evaluation set.

    Returns:
        A flat list with one flag per file, in directory-listing order:
        ``0`` for an evaluation sample, ``1`` for a training sample.

    Raises:
        ValueError: Raised by ``random.sample`` when ``split_rate`` yields a
            sample size that is negative or larger than the class.
    """
    train_test_list = []
    for item in os.listdir(dir):
        imagDir = os.path.join(os.path.abspath(dir), item)
        if not os.path.isdir(imagDir):
            continue
        print('imagDir', imagDir)
        # Skip nested directories: they are not samples, and counting them
        # would shift every following flag relative to the image list.
        images = [
            name
            for name in os.listdir(imagDir)
            if os.path.isfile(os.path.join(imagDir, name))
        ]
        # A set keeps the membership test below O(1) per file.
        eval_names = set(random.sample(images, k=int(len(images) * split_rate)))
        # 0 -> evaluation sample, 1 -> training sample.
        train_test_list.extend(0 if name in eval_names else 1 for name in images)
        print()
    return train_test_list
if __name__=="__main__":
    # Alternative output (kept for reference): write one "<image_id> <label>"
    # line per image to produce image_class_labels.txt.
    # file_list,label_list = get_file(dirname)
    # with open('train_test_split.txt', mode='w') as file_handle:
    #     for i,j in enumerate(label_list):
    #         file_handle.write('{} {} \n'.format(i + 1, j + 1))  # image_class_labels.txt
    # ------------------------split-test-spilt-V1
    # Superseded: V1 shuffled a fixed-length 0/1 array over the whole
    # dataset (80% ones), which cannot guarantee that every class appears
    # in both subsets.
    # nums = np.zeros(5378, dtype=int)
    # nums[:int(0.8 * len(nums))] = 1
    # np.random.shuffle(nums)
    # ------------------------split-test-spilt-V2
    # Per-class split: 20% of each class folder is flagged 0 (evaluation),
    # the rest 1 (training).
    nums = get_train_test(dirname,0.2)
    # The context manager closes the file even if a write fails.
    with open('train_test_split.txt', mode='w') as file_handle:
        for i, j in enumerate(nums):
            # One line per image: "<1-based image id> <flag> " (the trailing
            # space before the newline matches the original output format).
            file_handle.write('{} {} \n'.format(i + 1, j))  # train_test_split.txt
    # Count the flags directly instead of counting characters in the
    # list's string representation.
    n = nums.count(1)
    m = nums.count(0)
    print('nums',nums)
    print('train',n)
    print('value',m)