Generative Image Inpainting with Contextual Attention
Today I'm introducing the CVPR 2018 paper Generative Image Inpainting with Contextual Attention.
paper: https://arxiv.org/abs/1801.07892, demo: http://jiahuiyu.com/deepfill
github: https://github.com/JiahuiYu/generative_inpainting
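First, the results:
Those are the author's inpainting results. Below are the results from the models I trained myself:
There are two different sets of results here because I used two different pre-trained models.
How to use it: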
- Requirements:
  - Install python3.
  - Install tensorflow (tested on releases 1.3.0, 1.4.0, 1.5.0, 1.6.0, 1.7.0).
  - Install the tensorflow toolkit neuralgym (run pip install git+https://github.com/JiahuiYu/neuralgym).
- Training:
  - Prepare the training images filelist and shuffle it (example).
  - Modify inpaint.yml to set DATA_FLIST, LOG_DIR, IMG_SHAPES and other parameters.
  - Run python3 train.py
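Before training, you can check that an environment matches the requirements above. The sketch below compares the installed TensorFlow version against the tested releases and checks that neuralgym can be imported. is_tested_tf_version and check_environment are hypothetical helpers, not part of the project. They only read tf.__version__ and attempt the imports.

```python
import sys

# Releases listed as tested in the README requirements above
TESTED_TF = {(1, 3), (1, 4), (1, 5), (1, 6), (1, 7)}


def is_tested_tf_version(version):
    """Return True if a version string such as '1.7.0' is one of the tested releases."""
    parts = version.split('.')
    try:
        major, minor = int(parts[0]), int(parts[1])
    except (IndexError, ValueError):
        return False
    return (major, minor) in TESTED_TF


def check_environment():
    """Report the Python/TensorFlow setup; returns the TensorFlow version or None."""
    if sys.version_info < (3,):
        print("warning: the project expects python3")
    version = None
    try:
        import tensorflow as tf
        version = tf.__version__
        status = "tested" if is_tested_tf_version(version) else "not in the tested list"
        print("tensorflow %s (%s)" % (version, status))
    except ImportError:
        print("tensorflow is not installed")
    try:
        import neuralgym  # noqa: F401  (only checking that it can be imported)
    except ImportError:
        print("neuralgym is not installed")
    return version
```

Call check_environment() once after installing the packages. A TensorFlow release outside 1.3 to 1.7 may still work, but it is not among the versions the README lists as tested.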
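Here I focus on how to prepare your own training set, which a short Python script can handle automatically. gen_flist.py randomly splits the source dataset into a training set and a validation set, then writes both in the file-list format the project expects.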
# gen_flist.py: split an image folder into shuffled training / validation flists
import argparse
import os
import random

# Number of images held out for validation. Assumption: the original value is
# not preserved in this copy of the script; set it to suit your dataset.
_NUM_TEST = 2000

parser = argparse.ArgumentParser()
parser.add_argument('--folder_path', default='/home/gavin/Dataset/celeba', type=str,
                    help='The folder path of the source images.')
parser.add_argument('--train_filename', default='./data/celeba/train_shuffled.flist', type=str,
                    help='The train filename.')
parser.add_argument('--validation_filename', default='./data/celeba/validation_static_view.flist', type=str,
                    help='The validation filename.')


def _get_filenames(dataset_dir):
    """Return the full path of every file in dataset_dir."""
    image_list = os.listdir(dataset_dir)
    photo_filenames = [os.path.join(dataset_dir, _) for _ in image_list]
    return photo_filenames


def _write_flist(filename, file_names):
    """Write one path per line. open(..., "w") creates the file, so os.mknod is not needed."""
    parent = os.path.dirname(filename)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(filename, "w") as fo:
        fo.write("\n".join(file_names))


if __name__ == "__main__":
    args = parser.parse_args()
    data_dir = args.folder_path

    photo_filenames = _get_filenames(data_dir)
    print("size of celeba is %d" % (len(photo_filenames)))

    # Shuffle, then hold out the first _NUM_TEST images for validation
    random.shuffle(photo_filenames)
    training_file_names = photo_filenames[_NUM_TEST:]
    validation_file_names = photo_filenames[:_NUM_TEST]

    print("training file size:", len(training_file_names))
    print("validation file size:", len(validation_file_names))

    _write_flist(args.train_filename, training_file_names)
    _write_flist(args.validation_filename, validation_file_names)

    print("Written file is: ", args.train_filename)
    print("Written file is: ", args.validation_filename)
The generated files list one image path per line, as shown in the figure below:
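Before starting a long training run, it is worth checking the two lists. The minimal sketch below assumes the one-path-per-line format that gen_flist.py writes. It verifies that no image appears in both the training and validation lists, and that every listed file exists. read_flist and check_flists are hypothetical helpers.

```python
import os


def read_flist(path):
    """Return the non-empty lines of an .flist file (one image path per line)."""
    with open(path) as f:
        return [line.strip() for line in f if line.strip()]


def check_flists(train_path, val_path, check_exists=True):
    """Sanity-check a train/validation split such as the one gen_flist.py writes.

    Raises ValueError if an image is listed in both files or, when
    check_exists is True, if a listed path is not an existing file.
    Returns (number of training images, number of validation images).
    """
    train, val = read_flist(train_path), read_flist(val_path)
    overlap = set(train) & set(val)
    if overlap:
        raise ValueError("%d image(s) appear in both lists" % len(overlap))
    if check_exists:
        missing = [p for p in train + val if not os.path.isfile(p)]
        if missing:
            raise ValueError("%d listed file(s) missing, e.g. %s" % (len(missing), missing[0]))
    return len(train), len(val)
```

For example, check_flists('./data/celeba/train_shuffled.flist', './data/celeba/validation_static_view.flist') returns the two list sizes, or raises ValueError if the split is inconsistent.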
- Resume training:
  - Modify the MODEL_RESTORE flag in inpaint.yml, e.g. MODEL_RESTORE: 20180115220926508503_places2_model.
  - Run python3 train.py
- Testing:
  - Run python3 test.py --image examples/input.png --mask examples/mask.png --output examples/output.png --checkpoint model_logs/your_model_dir
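To resume training or run a test, you need to know which snapshot to point at. The helper below picks the newest one. It assumes the checkpoints follow the snap-N naming used in the test commands later in this post (e.g. model_logs/celebA_model/snap-60000). It also assumes TensorFlow's default checkpoint format, which writes a snap-N.index file for each save. latest_snapshot is a hypothetical name.

```python
import os
import re

# Matches the index file written for each checkpoint, e.g. snap-60000.index
_SNAP_RE = re.compile(r"^snap-(\d+)\.index$")


def latest_snapshot(model_dir):
    """Return the path prefix of the snap-N checkpoint with the largest N, or None."""
    steps = []
    for name in os.listdir(model_dir):
        match = _SNAP_RE.match(name)
        if match:
            steps.append(int(match.group(1)))
    if not steps:
        return None
    # Compare numerically: as strings, 'snap-90000' would sort after 'snap-100000'
    return os.path.join(model_dir, "snap-%d" % max(steps))
```

The returned prefix can be passed as the checkpoint argument of test.py. If the directory also contains TensorFlow's checkpoint state file, tf.train.latest_checkpoint(model_dir) is the built-in alternative.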
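That covers the main steps. Below are the scripts I actually used for training and testing.
Configuration file
In inpaint.yml, pay attention to the value of MODEL_RESTORE when resuming training from a saved model. Set it to the name of the model directory to restore, e.g. MODEL_RESTORE: 20180115220926508503_places2_model.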
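You can also script the MODEL_RESTORE change. This minimal sketch assumes that MODEL_RESTORE sits on its own line in inpaint.yml, in the "MODEL_RESTORE: value" form of the README example. set_model_restore is a hypothetical helper.

```python
import re

# A "MODEL_RESTORE: ..." line; [ \t]* (not \s*) keeps the match within one line
_RESTORE_RE = re.compile(r"^([ \t]*MODEL_RESTORE:).*$", re.MULTILINE)


def set_model_restore(yml_text, model_dir_name):
    """Return yml_text with the value of MODEL_RESTORE replaced by model_dir_name."""
    # A function replacement keeps backslashes in model_dir_name from being interpreted
    new_text, count = _RESTORE_RE.subn(lambda m: m.group(1) + " " + model_dir_name, yml_text)
    if count == 0:
        raise KeyError("no MODEL_RESTORE line found")
    return new_text
```

Editing the text line by line, instead of loading and re-dumping the YAML, keeps the other comments in inpaint.yml intact. Anything after the value on the MODEL_RESTORE line itself, such as a trailing comment, is replaced.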
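Multi-GPU training
To train on multiple GPUs, you need to change three places. Two are in inpaint.yml, where you specify the number of GPUs to use and their IDs. The third is the most important and also the easiest to overlook. It is in train.py; change this part: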
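# single GPU
trainer = ng.train.Trainer(
    # ... other arguments as in the original train.py (omitted here) ...
    max_iters=config.MAX_ITERS,
    graph_def=multigpu_graph_def,
    grads_summary=config.GRADS_SUMMARY,
    gradient_processor=gradient_processor,
    graph_def_kwargs={
        'model': model, 'data': data, 'config': config, 'loss_type': 'g'},
    # ... remaining arguments as in the original train.py ...
)

# multi-GPU: use MultiGPUTrainer and pass num_gpus
trainer = ng.train.MultiGPUTrainer(
    # ... other arguments as in the original train.py (omitted here) ...
    max_iters=config.MAX_ITERS,
    graph_def=multigpu_graph_def,
    grads_summary=config.GRADS_SUMMARY,
    gradient_processor=gradient_processor,
    graph_def_kwargs={
        'model': model, 'data': data, 'config': config, 'loss_type': 'g'},
    # ... remaining arguments as in the original train.py ...
    num_gpus=config.NUM_GPUS,
)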
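In other words, there are two ways to build the trainer. ng.train.Trainer runs on a single GPU. ng.train.MultiGPUTrainer runs in multi-GPU mode and additionally needs the argument num_gpus = config.NUM_GPUS.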
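Forgetting the third change is easy, so one option is to derive the trainer choice from the config. The sketch below encodes the rule just described. trainer_spec is a hypothetical helper, and the optional gpu_ids argument stands in for whatever GPU-ID list you set in inpaint.yml (its key name is not shown here).

```python
def trainer_spec(num_gpus, gpu_ids=None):
    """Hypothetical helper: choose the neuralgym trainer from the GPU count.

    Returns (class name in ng.train, extra keyword arguments): one GPU uses
    Trainer, several GPUs use MultiGPUTrainer with num_gpus. If gpu_ids is
    given, it must list exactly num_gpus devices.
    """
    if num_gpus < 1:
        raise ValueError("NUM_GPUS must be at least 1")
    if gpu_ids is not None and len(gpu_ids) != num_gpus:
        raise ValueError("NUM_GPUS is %d but %d GPU id(s) were given" % (num_gpus, len(gpu_ids)))
    if num_gpus == 1:
        return "Trainer", {}
    return "MultiGPUTrainer", {"num_gpus": num_gpus}
```

In train.py this could be used as name, extra = trainer_spec(config.NUM_GPUS) followed by trainer = getattr(ng.train, name)(..., **extra). The class and num_gpus then always follow NUM_GPUS in inpaint.yml instead of being edited by hand.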
Scripts:
# Resume training: modify the MODEL_RESTORE flag in inpaint.yml, e.g. MODEL_RESTORE: 20180115220926508503_places2_model.

# Test (README example)
python3 test.py --image examples/input.png --mask examples/mask.png --output examples/output.png --checkpoint_dir model_logs/your_model_dir

# Test on CelebA with a center mask
python3 test.py --image examples/celeba/celebahr_patches_164787_input.png --mask examples/center_mask_256.png \
    --output examples/output_celeba.png --checkpoint_dir model_logs/celebA_model/snap-60000

# Test a single image with a generated rectangular mask
# 1. generate the mask and the masked image
python3 generate_mask.py --img ./examples/celeba/000035.jpg --HEIGHT 64 --WIDTH 64
# 2. inpaint the masked image
python3 test.py --image ./data/mask_img/masked/000035.jpg --mask ./data/mask_img/mask/000035.jpg \
    --output examples/output_000035.png --checkpoint_dir model_logs/celebA_model/snap-90000
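The commands above inpaint one image at a time. To run test.py over a whole folder, you could build one command per image, as in the minimal sketch below. It assumes that masked images and masks share a file name (as the mask-generation script saves them) and that test.py accepts the --image/--mask/--output/--checkpoint_dir flags used above. build_test_commands and run_all are hypothetical helpers.

```python
import os
import subprocess


def build_test_commands(masked_dir, mask_dir, output_dir, checkpoint_dir):
    """Build one test.py command per masked image that has a mask with the same name."""
    commands = []
    for name in sorted(os.listdir(masked_dir)):
        mask_path = os.path.join(mask_dir, name)
        if not os.path.isfile(mask_path):
            continue  # skip images without a matching mask
        stem = os.path.splitext(name)[0]
        commands.append([
            "python3", "test.py",
            "--image", os.path.join(masked_dir, name),
            "--mask", mask_path,
            "--output", os.path.join(output_dir, "output_%s.png" % stem),
            "--checkpoint_dir", checkpoint_dir,
        ])
    return commands


def run_all(commands):
    """Run the commands one after another, stopping at the first failure."""
    for cmd in commands:
        subprocess.run(cmd, check=True)
```

For example: run_all(build_test_commands('./data/mask_img/masked', './data/mask_img/mask', 'examples', 'model_logs/celebA_model/snap-90000')). Each test.py call rebuilds the graph and reloads the checkpoint. That is simple but slow for many images, and loading the model once in a single process would be faster.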
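Testing
In actual testing, every image needs both an input and a mask, and we have to generate these ourselves. To make random masks easy to produce, I wrote the following code, which can generate both regular (rectangular) and irregular random masks: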
import argparse
import itertools
import os
import sys
from copy import deepcopy
from random import randint

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API; only random_bbox/bbox2mask use it

parser = argparse.ArgumentParser()
parser.add_argument('--img', default='./examples/celeba/000042.jpg', type=str,
                    help='The input img for single image (an empty string processes --input_dirimg instead)')
parser.add_argument('--input_dirimg', default='./data/mask_img/src_img/', type=str,
                    help='The input folder path for multi-images')
parser.add_argument('--output_dirmask', default='./data/mask_img/mask/', type=str,
                    help='The output file path of mask.')
parser.add_argument('--output_dirmasked', default='./data/mask_img/masked/', type=str,
                    help='The output file path of masked.')
parser.add_argument('--MAX_MASK_NUMS', default=16, type=int,
                    help='max numbers of masks')
parser.add_argument('--MAX_DELTA_HEIGHT', default=32, type=int,
                    help='max height of delta')
parser.add_argument('--MAX_DELTA_WIDTH', default=32, type=int,
                    help='max width of delta')
parser.add_argument('--HEIGHT', default=128, type=int,
                    help='height of the rectangular mask')
parser.add_argument('--WIDTH', default=128, type=int,
                    help='width of the rectangular mask')
parser.add_argument('--IMG_SHAPES', type=eval, default=(256, 256, 3))


def random_mask(height, width, config, channels=3):
    """Generates a random irregular mask with lines, circles and ellipses.

    Returns a uint8 array where 0 marks the hole and 1 marks pixels to keep.
    """
    img = np.zeros((height, width, channels), np.uint8)

    # Stroke size scales with the image size (at least 3 so randint(3, size) is valid)
    size = max(int((width + height) * 0.02), 3)
    if width < 64 or height < 64:
        raise Exception("Width and Height of mask must be at least 64!")

    # Random lines
    for _ in range(randint(1, config.MAX_MASK_NUMS)):
        x1, x2 = randint(1, width), randint(1, width)
        y1, y2 = randint(1, height), randint(1, height)
        thickness = randint(3, size)
        cv2.line(img, (x1, y1), (x2, y2), (1, 1, 1), thickness)

    # Random filled circles
    for _ in range(randint(1, config.MAX_MASK_NUMS)):
        x1, y1 = randint(1, width), randint(1, height)
        radius = randint(3, size)
        cv2.circle(img, (x1, y1), radius, (1, 1, 1), -1)

    # Random elliptic arcs
    for _ in range(randint(1, config.MAX_MASK_NUMS)):
        x1, y1 = randint(1, width), randint(1, height)
        s1, s2 = randint(1, width), randint(1, height)
        a1, a2, a3 = randint(3, 180), randint(3, 180), randint(3, 180)
        thickness = randint(3, size)
        cv2.ellipse(img, (x1, y1), (s1, s2), a1, a2, a3, (1, 1, 1), thickness)

    # Invert: drawn strokes become 0 (hole), background becomes 1 (keep)
    return 1 - img


def show_random_masks(config):
    """Preview 25 random masks (in a notebook, %matplotlib inline replaces plt.show())."""
    _, axes = plt.subplots(5, 5, figsize=(20, 20))
    axes = list(itertools.chain.from_iterable(axes))
    for i in range(len(axes)):
        img = random_mask(500, 500, config)
        axes[i].imshow(img * 255)
    plt.show()


# TensorFlow versions of the rectangular mask, adapted from the generative_inpainting
# repo (without its margin settings). random_mask_rect below does not use them.
def random_bbox(config):
    """Generate a random tlhw with configuration.

    Args:
        config: Config should have configuration including IMG_SHAPES,
            HEIGHT, WIDTH.

    Returns:
        tuple: (top, left, height, width)
    """
    img_shape = config.IMG_SHAPES
    img_height = img_shape[0]
    img_width = img_shape[1]
    maxt = img_height - config.HEIGHT
    maxl = img_width - config.WIDTH
    t = tf.random_uniform(
        [], minval=0, maxval=maxt, dtype=tf.int32)
    l = tf.random_uniform(
        [], minval=0, maxval=maxl, dtype=tf.int32)
    h = tf.constant(config.HEIGHT)
    w = tf.constant(config.WIDTH)
    return (t, l, h, w)


def bbox2mask(bbox, config, name='mask'):
    """Generate mask tensor from bbox.

    Args:
        bbox: configuration tuple, (top, left, height, width)
        config: Config should have configuration including IMG_SHAPES,
            MAX_DELTA_HEIGHT, MAX_DELTA_WIDTH.

    Returns:
        tf.Tensor: output with shape [1, H, W, 1]
    """
    def npmask(bbox, height, width, delta_h, delta_w):
        mask = np.zeros((1, height, width, 1), np.float32)
        h = np.random.randint(delta_h//2+1)
        w = np.random.randint(delta_w//2+1)
        mask[:, bbox[0]+h:bbox[0]+bbox[2]-h,
             bbox[1]+w:bbox[1]+bbox[3]-w, :] = 1.
        return mask

    with tf.variable_scope(name), tf.device('/cpu:0'):
        img_shape = config.IMG_SHAPES
        height = img_shape[0]
        width = img_shape[1]
        mask = tf.py_func(
            npmask,
            [bbox, height, width,
             config.MAX_DELTA_HEIGHT, config.MAX_DELTA_WIDTH],
            tf.float32, stateful=False)
        mask.set_shape([1] + [height, width] + [1])
    return mask


def random_mask_rect(img_path, config, bsave=True):
    """Resize the image to IMG_SHAPES and cut a random HEIGHT x WIDTH hole.

    Returns (masked_img, mask); in mask, 255 marks the hole, and the masked
    image has the hole painted white.
    """
    img_data = cv2.imread(img_path)
    if img_data is None:
        raise IOError('could not read: %s' % img_path)

    # Resize to the training shape so that image and mask have the same size
    img_shape = config.IMG_SHAPES
    img_height = img_shape[0]
    img_width = img_shape[1]
    image = cv2.resize(img_data, (img_width, img_height))
    rectangle = np.zeros(image.shape[0:2], dtype=np.uint8)

    # Random top-left corner that keeps the hole inside the image
    # (assumption: the original lines choosing x, y, w, h were lost)
    maxt = img_height - config.HEIGHT
    maxl = img_width - config.WIDTH
    x, y = randint(0, maxl), randint(0, maxt)
    w, h = config.WIDTH, config.HEIGHT
    # cv2.rectangle's corners are inclusive, hence the -1; thickness -1 fills it
    cv2.rectangle(rectangle, (x, y), (x + w - 1, y + h - 1), 255, -1)
    mask = rectangle

    masked_img = deepcopy(image)
    masked_img[mask == 255] = 255

    print("shape of mask:", mask.shape)
    print("shape of masked_img:", masked_img.shape)

    if bsave:
        name = os.path.basename(img_path)
        cv2.imwrite(os.path.join(config.output_dirmask, name), mask)
        cv2.imwrite(os.path.join(config.output_dirmasked, name), masked_img)
    return masked_img, mask


def check_dirs(config):
    """Create the input and output folders if they do not exist yet."""
    for folder in (config.input_dirimg, config.output_dirmask, config.output_dirmasked):
        if not os.path.exists(folder):
            os.makedirs(folder)


def load_mask(img_path, config, bsave=False):
    """Apply a random irregular mask to one image; in the saved mask, 255 marks the hole."""
    img = cv2.imread(img_path)
    if img is None:
        raise IOError('could not read: %s' % img_path)
    # Keep OpenCV's BGR order because the result is written back with cv2.imwrite
    shape = img.shape
    print("Shape of image is: ", shape)

    # random_mask returns 0 for the hole and 1 elsewhere
    mask = random_mask(shape[0], shape[1], config)

    masked_img = deepcopy(img)
    masked_img[mask == 0] = 255

    # Invert and scale so that the hole is white (255), as in random_mask_rect
    mask_out = (1 - mask) * 255

    if bsave:
        name = os.path.basename(img_path)
        cv2.imwrite(os.path.join(config.output_dirmask, name), mask_out)
        cv2.imwrite(os.path.join(config.output_dirmasked, name), masked_img)
    return masked_img, mask_out


def img2maskedImg(dataset_dir, config):
    """Apply random irregular masks to every image in dataset_dir."""
    image_list = os.listdir(dataset_dir)
    files = [os.path.join(dataset_dir, _) for _ in image_list]
    length = len(files)
    for index, jpg in enumerate(files):
        sys.stdout.write('\r>>Converting image %d/%d ' % (index + 1, length))
        sys.stdout.flush()
        try:
            load_mask(jpg, config, True)
        except Exception as e:
            print('\ncould not read:', jpg, e)
    sys.stdout.write('Convert Over!\n')


if __name__ == '__main__':
    config = parser.parse_args()
    check_dirs(config)

    # Assumption: the original mode-selection lines were lost. A non-empty --img
    # gets a rectangular mask; otherwise every image in --input_dirimg gets an irregular one.
    if config.img:
        img = config.img
        masked_img, mask = random_mask_rect(img, config)

        # OpenCV loads BGR while matplotlib expects RGB
        _, axes = plt.subplots(1, 3, figsize=(20, 5))
        axes[0].imshow(cv2.cvtColor(cv2.imread(img), cv2.COLOR_BGR2RGB))
        axes[1].imshow(mask, cmap='gray')
        axes[2].imshow(cv2.cvtColor(masked_img, cv2.COLOR_BGR2RGB))
        plt.show()
    else:
        img2maskedImg(config.input_dirimg, config)
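The same rectangular-mask logic can be separated from file I/O and OpenCV so that it can be unit-tested. The minimal NumPy sketch below does this. It keeps the convention of random_mask_rect above: 255 marks the hole, and the hole is painted white in the masked image. The function name rect_mask_pair and its rng parameter are hypothetical.

```python
import numpy as np


def rect_mask_pair(image, height, width, rng=None):
    """Return (masked_image, mask) for one random height x width rectangular hole.

    mask: uint8 array of shape image.shape[:2], 255 inside the hole, 0 elsewhere.
    masked_image: a copy of image with the hole painted white (255).
    """
    rng = np.random.RandomState() if rng is None else rng
    img_h, img_w = image.shape[:2]
    if height > img_h or width > img_w:
        raise ValueError("the hole is larger than the image")
    # randint's upper bound is exclusive, so +1 lets the hole touch the bottom/right edge
    top = rng.randint(0, img_h - height + 1)
    left = rng.randint(0, img_w - width + 1)
    mask = np.zeros((img_h, img_w), dtype=np.uint8)
    mask[top:top + height, left:left + width] = 255
    masked = image.copy()
    masked[mask == 255] = 255
    return masked, mask
```

Passing a seeded np.random.RandomState makes the hole position reproducible. This is convenient when you compare outputs from different checkpoints on the same masked input.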
Results (panels: mask, masked image, output):