
Tutorial: Multi-class classification output with Keras (two network architectures)

51自学网 2020-11-03 12:23:20


How can a single network classify two independent labels of one image at the same time?

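There are generally two options: build one network whose output treats each label as an attribute, or build a network with two branches, each producing the output for one label.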

1. Dataset (fashion)

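My dataset has 12 classes and 5,547 images in total. Six of the classes were downloaded from the web; the other six I crawled from the web and sorted myself.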

The dataset carries two kinds of information: color (black, red, blue, white) and clothing type (jeans, dress, short-sleeved shirt, shoes, bag). The classes are:

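Black dress: black_dress (333 images)
Black jeans: black_jeans (344 images)
Black short-sleeved shirt: black_shirt (436 images)
Black shoes: black_shoe (534 images)
Blue dress: blue_dress (386 images)
Blue jeans: blue_jeans (356 images)
Blue short-sleeved shirt: blue_shirt (369 images)
Red dress: red_dress (384 images)
Red short-sleeved shirt: red_shirt (332 images)
Red shoes: red_shoe (486 images)
White bag: white_bag (747 images)
White shoes: white_shoe (840 images)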

2. Building the network (single output)

2.1 A VGG-like network (SimpleNet)


 
    # SimpleNet lives in the cnn module that the training script imports
    from keras.models import Sequential
    from keras.layers import (Conv2D, MaxPooling2D, Activation, BatchNormalization,
                              Dropout, GlobalAveragePooling2D, Dense)
    from keras import backend as K


    class SimpleNet(object):
        def __init__(self, input_shape, classes, finalAct="softmax"):
            # input_shape = (height, width, channels) with channels_last
            self.input_shape = input_shape
            self.classes = classes
            self.finalAct = finalAct

            # channel axis used by BatchNormalization
            chanDim = -1
            if K.image_data_format() == "channels_first":
                chanDim = 1
            self.chanDim = chanDim

        def build_model(self):
            model = Sequential()
            # CONV => RELU => POOL
            model.add(Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), padding="same",
                             input_shape=self.input_shape))
            model.add(Activation("relu"))
            model.add(BatchNormalization(axis=self.chanDim))
            model.add(MaxPooling2D(pool_size=(3, 3)))
            model.add(Dropout(0.25))

            # (CONV => RELU) * 2 => POOL
            model.add(Conv2D(64, (3, 3), padding="same"))
            model.add(Activation("relu"))
            model.add(BatchNormalization(axis=self.chanDim))
            model.add(Conv2D(64, (3, 3), padding="same"))
            model.add(Activation("relu"))
            model.add(BatchNormalization(axis=self.chanDim))
            model.add(MaxPooling2D(pool_size=(2, 2)))
            model.add(Dropout(0.25))

            # (CONV => RELU) * 2 => POOL
            model.add(Conv2D(128, (3, 3), padding="same"))
            model.add(Activation("relu"))
            model.add(BatchNormalization(axis=self.chanDim))
            model.add(Conv2D(128, (3, 3), padding="same"))
            model.add(Activation("relu"))
            model.add(BatchNormalization(axis=self.chanDim))
            model.add(MaxPooling2D(pool_size=(2, 2)))
            model.add(Dropout(0.25))

            # (CONV => RELU) * 2 => POOL
            model.add(Conv2D(256, (3, 3), padding="same"))
            model.add(Activation("relu"))
            model.add(BatchNormalization(axis=self.chanDim))
            model.add(Conv2D(256, (3, 3), padding="same"))
            model.add(Activation("relu"))
            model.add(BatchNormalization(axis=self.chanDim))
            model.add(MaxPooling2D(pool_size=(2, 2)))
            model.add(Dropout(0.25))

            # use global average pooling instead of a fully connected layer
            model.add(GlobalAveragePooling2D())
            model.add(Activation("relu"))
            model.add(BatchNormalization())
            model.add(Dropout(0.5))

            # classifier: softmax for multi-class, sigmoid for multi-label
            model.add(Dense(self.classes))
            model.add(Activation(self.finalAct))
            model.summary()

            return model
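SimpleNet's code comment says it uses global average pooling instead of a fully connected layer. The sketch below makes the saving concrete, assuming 96×96 inputs (as in the training script) and the 9 labels that MultiLabelBinarizer produces for this dataset (4 colors + 5 types). It relies on Keras `MaxPooling2D` defaulting to strides equal to `pool_size` with "valid" padding, while the `padding="same"` convolutions keep the spatial size. The helper names are hypothetical.

```python
# Hypothetical helpers: trace SimpleNet's feature-map size and compare the
# parameter count of the final Dense layer with and without global pooling.
def pooled_size(size, pool):
    # MaxPooling2D with strides=pool_size and padding="valid"
    return (size - pool) // pool + 1

def simplenet_feature_shape(input_size=96):
    size = pooled_size(input_size, 3)   # block 1: MaxPooling2D((3, 3))
    for _ in range(3):                  # blocks 2-4: MaxPooling2D((2, 2))
        size = pooled_size(size, 2)
    return (size, size, 256)            # the last conv block has 256 filters

def dense_params(in_features, units):
    # weights plus one bias per unit
    return in_features * units + units

h, w, c = simplenet_feature_shape(96)
num_labels = 9                                          # assumed: 4 colors + 5 types
gap_params = dense_params(c, num_labels)                # GlobalAveragePooling2D -> 256 features
flatten_params = dense_params(h * w * c, num_labels)    # Flatten -> h*w*256 features
print((h, w, c), gap_params, flatten_params)
```

With a 4×4×256 final feature map, pooling shrinks the classifier's input from 4,096 to 256 features, which is the motivation behind the comment in the code.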

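Note: this network only learns the 12 color/type combinations in the dataset (encoded as 9 independent labels: 4 colors and 5 types). If an image shows a combination it has never seen, such as a red bag, it is likely to be misclassified.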

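Softmax is the standard output layer for multi-class classification, where exactly one class is correct. In multi-label classification the labels are independent, so each label is treated as its own binary classification problem, and the usual activation for that is the sigmoid.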

For the same reason, multi-label classification usually uses the binary_crossentropy loss instead of the categorical_crossentropy loss used in multi-class classification.
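To make the difference concrete, here is a plain-Python sketch with made-up logits for three labels. Softmax makes the labels compete and sum to 1, while per-label sigmoids can mark several labels as present; binary cross-entropy then scores each label as its own yes/no problem and averages over labels (as Keras's `binary_crossentropy` does along the last axis). The function names are illustrative, not Keras APIs.

```python
import math

def softmax(logits):
    # subtract the max logit for numerical stability
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def categorical_crossentropy(y_true, probs):
    # one-hot target: only the true class's log-probability contributes
    return -sum(t * math.log(p) for t, p in zip(y_true, probs))

def binary_crossentropy(y_true, probs):
    # one yes/no loss per label, averaged over the labels
    losses = [-(t * math.log(p) + (1 - t) * math.log(1 - p))
              for t, p in zip(y_true, probs)]
    return sum(losses) / len(losses)

logits = [2.0, 1.5, -1.0]            # made-up scores for three labels
soft = softmax(logits)               # competing classes, sums to 1
sig = [sigmoid(z) for z in logits]   # independent per-label probabilities

single_label = [1, 0, 0]             # multi-class: exactly one class is true
multi_label = [1, 1, 0]              # multi-label: e.g. a color and a type both true

print("softmax:", [round(p, 3) for p in soft])
print("sigmoid:", [round(p, 3) for p in sig])
print("CCE:", round(categorical_crossentropy(single_label, soft), 4))
print("BCE:", round(binary_crossentropy(multi_label, sig), 4))
```

With a 0.5 threshold, the sigmoid scores mark the first two labels as present, whereas softmax can push at most one label above 0.5 because its outputs sum to 1.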

2.2 A multi-branch network (FashionNet)

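In this network, one branch recognizes the clothing type and the other the color. Type recognition is mostly about shape, so its branch can be deeper and receives a grayscale version of the image; color recognition is simpler, so its branch is smaller.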
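Why grayscale suits the type branch: after conversion, very different colors can land on almost the same intensity, so the branch has to rely on shape rather than color. A small sketch, assuming the luma weights 0.2989, 0.5870, 0.1140 that `tf.image.rgb_to_grayscale` applies; the helper `to_gray` is hypothetical. Pure red and a dark green both end up at an intensity of about 76.

```python
# Luma weights applied by tf.image.rgb_to_grayscale (ITU-R BT.601)
RGB_TO_GRAY = (0.2989, 0.5870, 0.1140)

def to_gray(rgb):
    # weighted sum of the R, G, B channels -> one intensity value
    return sum(w * c for w, c in zip(RGB_TO_GRAY, rgb))

pure_red = (255, 0, 0)
dark_green = (0, 130, 0)
print(round(to_gray(pure_red), 2), round(to_gray(dark_green), 2))
```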

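The advantage of this design is that it can recognize combinations that do not appear in the training data, such as blue shoes or a red bag, which the previous network cannot.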
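With 4 colors and 5 types, two independent heads can in principle express 4 × 5 = 20 combinations, while the dataset contains only 12. The sketch below lists the combinations that never appear among the training folders; these are the cases where the two-branch design can help. Whether the trained model actually gets them right still depends on the data.

```python
from itertools import product

colors = ["black", "blue", "red", "white"]
categories = ["bag", "dress", "jeans", "shirt", "shoe"]

# the 12 folder names (color_type) listed in the dataset section
seen = {
    "black_dress", "black_jeans", "black_shirt", "black_shoe",
    "blue_dress", "blue_jeans", "blue_shirt",
    "red_dress", "red_shirt", "red_shoe",
    "white_bag", "white_shoe",
}

all_combos = {f"{c}_{t}" for c, t in product(colors, categories)}
unseen = sorted(all_combos - seen)
print(len(all_combos), len(seen), len(unseen))
print(unseen)
```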


 
    # FashionNet also lives in the cnn module that the training script imports
    import tensorflow as tf
    from keras.models import Model
    from keras.layers import (Input, Lambda, Conv2D, MaxPooling2D, Activation,
                              BatchNormalization, Dropout, GlobalAveragePooling2D,
                              Flatten, Dense)
    from keras import backend as K


    class FashionNet(object):
        def __init__(self, input_shape, category_classes, color_classes, finalAct="softmax"):
            # input_shape = (height, width, channels) with channels_last
            self.input_shape = input_shape
            self.category_classes = category_classes
            self.color_classes = color_classes
            self.finalAct = finalAct

            # channel axis used by BatchNormalization
            chanDim = -1
            if K.image_data_format() == "channels_first":
                chanDim = 1
            self.chanDim = chanDim

        def build_category_branch(self, inputs):
            # convert the 3-channel (RGB) input to grayscale so this branch focuses on shape;
            # because the Lambda calls tf, loading the saved model needs custom_objects={'tf': tf}
            x = Lambda(lambda c: tf.image.rgb_to_grayscale(c))(inputs)

            # CONV => RELU => BN => POOL
            x = Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
            x = Activation('relu')(x)
            x = BatchNormalization(axis=self.chanDim)(x)
            x = MaxPooling2D(pool_size=(3, 3))(x)

            # (CONV => RELU) * 2 => POOL
            x = Conv2D(64, (3, 3), padding="same")(x)
            x = Activation("relu")(x)
            x = BatchNormalization(axis=self.chanDim)(x)
            x = Conv2D(64, (3, 3), padding="same")(x)
            x = Activation("relu")(x)
            x = BatchNormalization(axis=self.chanDim)(x)
            x = MaxPooling2D(pool_size=(2, 2))(x)
            x = Dropout(0.25)(x)

            # (CONV => RELU) * 2 => POOL
            x = Conv2D(128, (3, 3), padding="same")(x)
            x = Activation("relu")(x)
            x = BatchNormalization(axis=self.chanDim)(x)
            x = Conv2D(128, (3, 3), padding="same")(x)
            x = Activation("relu")(x)
            x = BatchNormalization(axis=self.chanDim)(x)
            x = MaxPooling2D(pool_size=(2, 2))(x)
            x = Dropout(0.25)(x)

            # (CONV => RELU) * 2 => POOL
            x = Conv2D(256, (3, 3), padding="same")(x)
            x = Activation("relu")(x)
            x = BatchNormalization(axis=self.chanDim)(x)
            x = Conv2D(256, (3, 3), padding="same")(x)
            x = Activation("relu")(x)
            x = BatchNormalization(axis=self.chanDim)(x)
            x = MaxPooling2D(pool_size=(2, 2))(x)
            x = Dropout(0.25)(x)

            # use global average pooling instead of a fully connected layer
            x = GlobalAveragePooling2D()(x)
            x = Activation("relu")(x)
            x = BatchNormalization()(x)
            x = Dropout(0.5)(x)

            # softmax classifier for the clothing type
            x = Dense(self.category_classes)(x)
            x = Activation(self.finalAct, name='category_output')(x)

            return x

        def build_color_branch(self, inputs):
            # CONV => RELU => BN => POOL (smaller filters: color is easier than shape)
            x = Conv2D(filters=16, kernel_size=(3, 3), strides=(1, 1), padding='same')(inputs)
            x = Activation('relu')(x)
            x = BatchNormalization(axis=self.chanDim)(x)
            x = MaxPooling2D(pool_size=(3, 3))(x)

            # (CONV => RELU => BN) * 2 => POOL
            x = Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
            x = Activation('relu')(x)
            x = BatchNormalization(axis=self.chanDim)(x)
            x = Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
            x = Activation('relu')(x)
            x = BatchNormalization(axis=self.chanDim)(x)
            x = MaxPooling2D(pool_size=(2, 2))(x)
            x = Dropout(0.25)(x)

            # (CONV => RELU => BN) * 2 => POOL
            x = Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
            x = Activation('relu')(x)
            x = BatchNormalization(axis=self.chanDim)(x)
            x = Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
            x = Activation('relu')(x)
            x = BatchNormalization(axis=self.chanDim)(x)
            x = MaxPooling2D(pool_size=(2, 2))(x)
            x = Dropout(0.25)(x)

            # fully connected head, then softmax classifier for the color
            x = Flatten()(x)
            x = Dense(128)(x)
            x = Activation('relu')(x)
            x = BatchNormalization()(x)
            x = Dropout(0.5)(x)
            x = Dense(self.color_classes)(x)
            x = Activation(self.finalAct, name='color_output')(x)
            return x

        def build_model(self):
            # one shared input feeds both branches; the model has two named outputs
            inputs = Input(shape=self.input_shape)
            category_branch = self.build_category_branch(inputs)
            color_branch = self.build_color_branch(inputs)

            model = Model(inputs=inputs, outputs=[category_branch, color_branch])
            model.summary()
            return model
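Because FashionNet names its two outputs `category_output` and `color_output`, each output can be given its own loss and weight at compile time (`loss` and `loss_weights` dictionaries), and Keras minimizes the weighted sum of the per-output losses. A plain-Python sketch for a single image, with made-up one-hot targets and softmax outputs; the helper names are hypothetical and the numbers are not results from the article.

```python
import math

def categorical_crossentropy(y_true, probs):
    # one-hot target: -log(probability of the true class)
    return -sum(t * math.log(p) for t, p in zip(y_true, probs))

def weighted_total_loss(outputs, loss_weights):
    # outputs maps an output name to (one-hot target, softmax probabilities)
    per_output = {name: categorical_crossentropy(y, p) for name, (y, p) in outputs.items()}
    total = sum(loss_weights[name] * loss for name, loss in per_output.items())
    return total, per_output

# made-up targets and predictions for one image (5 types, 4 colors)
outputs = {
    "category_output": ([0, 1, 0, 0, 0], [0.05, 0.80, 0.05, 0.05, 0.05]),
    "color_output": ([1, 0, 0, 0], [0.70, 0.10, 0.10, 0.10]),
}
equal, parts = weighted_total_loss(outputs, {"category_output": 1.0, "color_output": 1.0})
color_heavy, _ = weighted_total_loss(outputs, {"category_output": 1.0, "color_output": 2.0})
print({k: round(v, 4) for k, v in parts.items()}, round(equal, 4), round(color_heavy, 4))
```

Raising one output's weight makes its errors count more in the total, which is the knob to turn if one branch lags behind the other.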

3. Model training

The training script provides separate functions for the two approaches:


 
    # -*- coding: utf-8 -*-
    # import the necessary packages
    from keras.optimizers import Adam
    from keras.preprocessing.image import img_to_array
    from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
    from sklearn.model_selection import train_test_split
    from cnn import SimpleNet
    from cnn import FashionNet
    import matplotlib.pyplot as plt
    from imutils import paths
    import numpy as np
    import random
    import pickle
    import cv2
    import os
    from PIL import Image


    # load images and multi-label targets (SimpleNet)
    def load_data(data_dir, img_size):
        print("[INFO] loading images...")
        if not os.path.exists(data_dir):
            return None
        # grab the image paths and randomly shuffle them
        imagePaths = sorted(list(paths.list_images(data_dir)))
        random.seed(42)
        random.shuffle(imagePaths)

        datas = []
        labels = []
        for imagePath in imagePaths:
            image = cv2.imread(imagePath, cv2.IMREAD_UNCHANGED)
            if image is None:
                print(imagePath)
                continue
            # convert 8-bit single-channel images to 24-bit BGR
            if len(image.shape) == 2:
                with Image.open(imagePath) as img:
                    rgb_img = img.convert('RGB')
                    image = cv2.cvtColor(np.asarray(rgb_img), cv2.COLOR_RGB2BGR)
            elif len(image.shape) == 3:
                if image.shape[2] == 4:
                    image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
                elif image.shape[2] == 1:
                    image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

            image = cv2.resize(image, img_size)
            image = img_to_array(image)
            datas.append(image)

            # the folder name, e.g. "black_dress", gives both labels: ["black", "dress"]
            label = imagePath.split(os.path.sep)[-2].split("_")
            labels.append(label)

        # scale the raw pixel intensities to the range [0, 1]
        datas = np.array(datas, dtype="float") / 255.0
        labels = np.array(labels)
        return datas, labels


    # load images with separate category and color targets (FashionNet)
    def load_data_multilabels(data_dir, img_size):
        print("[INFO] loading images...")
        if not os.path.exists(data_dir):
            return None
        imagePaths = sorted(list(paths.list_images(data_dir)))
        random.seed(42)
        random.shuffle(imagePaths)

        datas = []
        category_labels = []
        color_labels = []
        for imagePath in imagePaths:
            image = cv2.imread(imagePath)
            if image is None:
                print(imagePath)
                continue
            if image.shape[2] == 4:
                image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
            image = cv2.resize(image, img_size)
            # FashionNet's grayscale Lambda expects RGB channel order
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = img_to_array(image)
            datas.append(image)

            (color_label, category_label) = imagePath.split(os.path.sep)[-2].split("_")
            category_labels.append(category_label)
            color_labels.append(color_label)

        # scale the raw pixel intensities to the range [0, 1]
        datas = np.array(datas, dtype="float") / 255.0
        category_labels = np.array(category_labels)
        color_labels = np.array(color_labels)
        return datas, category_labels, color_labels


    # binarize the labels using scikit-learn's multi-label binarizer (multi-hot vectors)
    def binarize_multilabels_and_save(labels, path):
        mlb = MultiLabelBinarizer()
        labels = mlb.fit_transform(labels)
        print(labels[:6])
        print('labels shape:', labels.shape)
        for (i, label) in enumerate(mlb.classes_):
            print("{}. {}".format(i + 1, label))
        with open(path, "wb") as f:
            f.write(pickle.dumps(mlb))
        return labels, len(mlb.classes_)


    # one-hot encode category and color separately, one binarizer per output
    def binarize_labels_and_save(category_labels, color_labels, category_path, color_path):
        category_lb = LabelBinarizer()
        color_lb = LabelBinarizer()
        category_labels = category_lb.fit_transform(category_labels)
        color_labels = color_lb.fit_transform(color_labels)

        # loop over each of the possible class labels and show them
        for (i, label) in enumerate(category_lb.classes_):
            print("category {}. {}".format(i + 1, label))

        for (i, label) in enumerate(color_lb.classes_):
            print("color {}. {}".format(i + 1, label))

        with open(category_path, "wb") as f:
            f.write(pickle.dumps(category_lb))

        with open(color_path, "wb") as f:
            f.write(pickle.dumps(color_lb))
        return category_labels, color_labels, len(category_lb.classes_), len(color_lb.classes_)


    # model_type: 'SimpleNet' or 'SmallerInceptionNet'
    def train_model(datas, labels, classes, finalAct='sigmoid', model_type='SimpleNet'):
        EPOCHS = 20
        INIT_LR = 1e-3
        BATCH_SIZE = 32
        INPUT_SHAPE = (96, 96, 3)
        (trainX, testX, trainY, testY) = train_test_split(datas, labels, test_size=0.2, random_state=42)
        if model_type == 'SimpleNet':
            simpleNet = SimpleNet(INPUT_SHAPE, classes, finalAct)
            model = simpleNet.build_model()
        else:
            # SmallerInceptionNet is not shown in this article; import it only when requested
            from cnn import SmallerInceptionNet
            smallerInceptionNet = SmallerInceptionNet()
            model = smallerInceptionNet.build_model(INPUT_SHAPE, classes, finalAct)

        opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)
        # independent sigmoid outputs -> binary cross-entropy per label
        model.compile(loss="binary_crossentropy", optimizer=opt, metrics=["accuracy"])

        history = model.fit(trainX, trainY, batch_size=BATCH_SIZE,
                            epochs=EPOCHS, verbose=1,
                            validation_data=(testX, testY))

        model.save('trained_mode/' + '{}.h5'.format(model_type))


    def train_fashionnet_model(datas, category_labels, color_labels, category_classes, color_classes, finalAct='softmax'):
        EPOCHS = 30
        INIT_LR = 1e-3
        BATCH_SIZE = 32
        INPUT_SHAPE = (96, 96, 3)
        (trainX, testX, trainCategoryY, testCategoryY, trainColorY, testColorY) = train_test_split(
            datas, category_labels, color_labels, test_size=0.2, random_state=42)

        fashionNet = FashionNet(INPUT_SHAPE, category_classes=category_classes,
                                color_classes=color_classes, finalAct=finalAct)
        model = fashionNet.build_model()
        # one loss and one weight per named output
        losses = {'category_output': 'categorical_crossentropy', 'color_output': 'categorical_crossentropy'}
        loss_weights = {'category_output': 1.0, 'color_output': 1.0}

        opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)
        model.compile(optimizer=opt, loss=losses, loss_weights=loss_weights, metrics=["accuracy"])

        history = model.fit(trainX, {'category_output': trainCategoryY, 'color_output': trainColorY},
                            batch_size=BATCH_SIZE, epochs=EPOCHS,
                            verbose=1,
                            validation_data=(testX, {'category_output': testCategoryY, 'color_output': testColorY}))

        model.save('trained_mode/' + '{}.h5'.format('FashionNet'))

        plot_fashionnet_loss_acc(history, EPOCHS)


    def plot_loss_acc(history, EPOCHS):
        plt.style.use("ggplot")
        plt.figure()
        N = EPOCHS
        plt.plot(np.arange(0, N), history.history["loss"], label="train_loss")
        plt.plot(np.arange(0, N), history.history["val_loss"], label="val_loss")
        plt.plot(np.arange(0, N), history.history["acc"], label="train_acc")
        plt.plot(np.arange(0, N), history.history["val_acc"], label="val_acc")
        plt.title("Training Loss and Accuracy")
        plt.xlabel("Epoch #")
        plt.ylabel("Loss/Accuracy")
        plt.legend(loc="upper left")
        plt.savefig('plot_loss_acc.png')


    def plot_fashionnet_loss_acc(history, EPOCHS):
        loss_names = ['loss', 'category_output_loss', 'color_output_loss']
        plt.style.use("ggplot")
        (fig, ax) = plt.subplots(3, 1, figsize=(13, 13))

        for (i, l) in enumerate(loss_names):
            title = 'Loss for {}'.format(l) if l != 'loss' else 'Total loss'
            ax[i].set_title(title)
            ax[i].set_xlabel('Epoch #')
            ax[i].set_ylabel('Loss')
            ax[i].plot(np.arange(0, EPOCHS), history.history[l], label=l)
            ax[i].plot(np.arange(0, EPOCHS), history.history["val_" + l], label="val_" + l)
            ax[i].legend()
        plt.savefig('plot_fashionnet_losses.png')
        plt.close()
        '''
        # accuracy plots (disabled)
        accuracy_names = ['category_output_acc', 'color_output_acc']
        plt.style.use("ggplot")
        (fig, ax) = plt.subplots(2, 1, figsize=(8, 8))
        for (i, l) in enumerate(accuracy_names):
            title = 'Accuracy for {}'.format(l)
            ax[i].set_title(title)
            ax[i].set_xlabel('Epoch #')
            ax[i].set_ylabel('Accuracy')
            ax[i].plot(np.arange(0, EPOCHS), history.history[l], label=l)
            ax[i].plot(np.arange(0, EPOCHS), history.history["val_" + l], label="val_" + l)
            ax[i].legend()
        plt.savefig('plot_fashionnet_accs.png')
        plt.close()
        '''


    def main():
        data_dir = './dataset'
        img_size = (96, 96)
        label_dir = './labels'
        if not os.path.exists(label_dir):
            os.mkdir(label_dir)
        # saving the .h5 model fails if the target directory does not exist
        os.makedirs('trained_mode', exist_ok=True)

        # SimpleNet (one multi-label output): remove the quotes to enable it
        '''
        datas, labels = load_data(data_dir, img_size)
        labels, classes = binarize_multilabels_and_save(labels, os.path.join(label_dir, 'multi-label.pickle'))
        train_model(datas, labels, classes, finalAct='sigmoid', model_type='SimpleNet')
        '''
        # FashionNet (two outputs)
        datas, category_labels, color_labels = load_data_multilabels(data_dir, img_size)
        category_path = os.path.join(label_dir, 'category.pickle')
        color_path = os.path.join(label_dir, 'color.pickle')
        category_labels, color_labels, category_classes, color_classes = binarize_labels_and_save(
            category_labels, color_labels, category_path, color_path)
        train_fashionnet_model(datas, category_labels, color_labels, category_classes, color_classes, finalAct='softmax')


    if __name__ == '__main__':
        main()
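Both training functions pass `decay=INIT_LR / EPOCHS` to `Adam`. In Keras 2.x optimizers this legacy `decay` argument shrinks the learning rate as `lr / (1 + decay * iterations)`, where `iterations` counts batch updates rather than epochs, so the effect is milder than the name suggests. A worked sketch for `train_model`, assuming all 5,547 images load (skipped images would lower the counts); newer Keras versions rename `lr` to `learning_rate` and favor learning-rate schedules instead of `decay`. The helper name is hypothetical.

```python
import math

def legacy_keras_lr(init_lr, decay, iterations):
    # Keras 2.x time-based decay: lr / (1 + decay * iterations), per batch update
    return init_lr / (1.0 + decay * iterations)

INIT_LR, EPOCHS, BATCH_SIZE = 1e-3, 20, 32
decay = INIT_LR / EPOCHS                          # 5e-05, as in train_model
n_images = 5547                                   # assumes every image loads
n_train = n_images - math.ceil(0.2 * n_images)    # train_test_split(test_size=0.2)
steps_per_epoch = math.ceil(n_train / BATCH_SIZE)
final_lr = legacy_keras_lr(INIT_LR, decay, steps_per_epoch * EPOCHS)
print(n_train, steps_per_epoch, f"{final_lr:.3e}")
```

After 20 epochs the learning rate has only dropped by roughly 12 percent from its initial value.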

4. Test code


 
    # import the necessary packages
    from keras.preprocessing.image import img_to_array
    from keras.models import load_model
    import numpy as np
    import imutils
    import pickle
    import cv2
    import os
    import tensorflow as tf

    # load the image
    # model_type: None (SimpleNet) or 'FashionNet'
    def load_image(img_path, model_type=None):
        image = cv2.imread(img_path)
        output = imutils.resize(image, width=400)
        if model_type == 'FashionNet':
            # FashionNet was trained on RGB images
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # pre-process the image for classification
        image = cv2.resize(image, (96, 96))
        image = image.astype("float") / 255.0
        image = img_to_array(image)
        image = np.expand_dims(image, axis=0)
        return image, output


    def load_trained_model(img, model_path, labelbin_path):
        with open(labelbin_path, "rb") as f:
            label_lb = pickle.load(f)
        model = load_model(model_path)
        proba = model.predict(img)[0]

        # take the two labels with the highest sigmoid scores
        idxs = np.argsort(proba)[::-1][:2]
        label_1 = label_lb.classes_[idxs[0]]
        label_2 = label_lb.classes_[idxs[1]]

        proba_1 = proba[idxs[0]]
        proba_2 = proba[idxs[1]]

        result = (label_1, proba_1, label_2, proba_2)
        return result


    # load the trained convolutional neural network
    def load_trained_fashionnet_model(img, model_path, categorybin_path, colorbin_path):
        with open(categorybin_path, "rb") as f:
            category_lb = pickle.load(f)
        with open(colorbin_path, "rb") as f:
            color_lb = pickle.load(f)

        # the grayscale Lambda layer references tf, so pass it in custom_objects
        model = load_model(model_path, custom_objects={'tf': tf})
        (category_proba, color_proba) = model.predict(img)

        category_idx = category_proba[0].argmax()
        color_idx = color_proba[0].argmax()
        category_label = category_lb.classes_[category_idx]
        color_label = color_lb.classes_[color_idx]

        category_proba = category_proba[0][category_idx]
        color_proba = color_proba[0][color_idx]
        result = (category_label, category_proba, color_label, color_proba)
        return result


    def show_result(img, result):
        (label_1, proba_1, label_2, proba_2) = result
        text1 = "{}: {:.2f}%".format(label_1, proba_1 * 100)
        text2 = "{}: {:.2f}%".format(label_2, proba_2 * 100)

        cv2.putText(img, text1, (10, 25),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        cv2.putText(img, text2, (10, 55),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        # show the output image
        cv2.imshow("Output", img)
        cv2.waitKey(2000)
        cv2.destroyAllWindows()


    def show_fashionnet_result(img, result):
        (category_label, category_proba, color_label, color_proba) = result
        category_text = "category: {}: {:.2f}%".format(category_label, category_proba * 100)
        color_text = "color: {}: {:.2f}%".format(color_label, color_proba * 100)

        cv2.putText(img, category_text, (10, 25),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        cv2.putText(img, color_text, (10, 55),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        # show the output image
        cv2.imshow("Output", img)
        cv2.waitKey(2000)
        cv2.destroyAllWindows()


    if __name__ == '__main__':
        test_dir = './examples'
        # model_type = 'FashionNet'
        model_type = None
        for img in os.listdir(test_dir):
            img_path = os.path.join(test_dir, img)
            if model_type is None:
                image, output = load_image(img_path)
                model_path = 'trained_mode/SimpleNet.h5'
                labelbin_path = './labels/multi-label.pickle'
                result = load_trained_model(image, model_path, labelbin_path)
                show_result(output, result)
            elif model_type == 'FashionNet':
                image, output = load_image(img_path, model_type)
                model_path = 'trained_mode/FashionNet.h5'
                categorybin_path = './labels/category.pickle'
                colorbin_path = './labels/color.pickle'

                result = load_trained_fashionnet_model(image, model_path, categorybin_path, colorbin_path)
                show_fashionnet_result(output, result)
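`load_trained_model` assumes the two highest SimpleNet scores are one color and one type, but nothing in independent sigmoid outputs enforces that. A plain-Python sketch, with made-up scores and the alphabetical 9-label order that MultiLabelBinarizer produces for this dataset, contrasting the top-2 selection used above with picking the best label from each group. The grouped pick is an alternative shown for comparison, not the article's method, and all names are hypothetical.

```python
# Class order produced by MultiLabelBinarizer for this dataset (sorted alphabetically)
CLASSES = ["bag", "black", "blue", "dress", "jeans", "red", "shirt", "shoe", "white"]
COLORS = {"black", "blue", "red", "white"}

def top2(proba):
    # same idea as np.argsort(proba)[::-1][:2] in load_trained_model
    idxs = sorted(range(len(proba)), key=lambda i: proba[i], reverse=True)[:2]
    return [CLASSES[i] for i in idxs]

def best_color_and_type(proba):
    # alternative: choose the highest-scoring label within each group
    color = max((i for i, c in enumerate(CLASSES) if c in COLORS), key=lambda i: proba[i])
    kind = max((i for i, c in enumerate(CLASSES) if c not in COLORS), key=lambda i: proba[i])
    return CLASSES[color], CLASSES[kind]

proba = [0.05, 0.90, 0.60, 0.10, 0.30, 0.05, 0.20, 0.40, 0.02]  # made-up sigmoid scores
print(top2(proba))                 # two colors win: ['black', 'blue']
print(best_color_and_type(proba))  # ('black', 'shoe')
```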
     

5. Data and complete code

Code: https://github.com/zhangwei147258/fashion_mutil_label_classifier_keras

Data: https://pan.baidu.com/s/11LoY2H5shADwiQwPuhB6ng (extraction code: pg7d)

