import
cv2
import
matplotlib.pyplot as plt
import
numpy as np
from
keras
import
backend as K
from
keras.preprocessing
import
image
def heatmap(model, data_img, layer_idx, img_show=None, pred_idx=None, alpha=0.4):
    """Compute a gradient-weighted class-activation heatmap for one layer.

    The gradient of the chosen class score w.r.t. the output of
    ``model.layers[layer_idx]`` is averaged over axes (0, 1, 2) to obtain one
    weight per channel. The weighted feature map is averaged over channels,
    clipped at zero, normalised to [0, 1], resized to ``img_show`` and blended
    onto it as a JET colour map.

    NOTE(review): relies on ``K.gradients`` / ``K.function`` (graph-mode Keras
    backend API); confirm this matches the Keras/TensorFlow version in use.

    Args:
        model: Keras model whose ``input``, ``output`` and ``layers`` are used.
        data_img: Either a single image array (rank != 4), which is resized to
            the model's input size and given a batch axis, or an already
            prepared 4-D batch that is fed to the model unchanged.
        layer_idx: Index into ``model.layers`` of the layer to visualise. Its
            per-sample output is indexed as ``[:, :, channel]``, so a
            channels-last feature map is assumed.
        img_show: Image the heatmap is overlaid on. Defaults to ``data_img``
            (or its first sample when ``data_img`` is already 4-D).
        pred_idx: Class index to explain. Defaults to the arg-max of the
            model's prediction for ``data_img``.
        alpha: Weight of the colour-mapped heatmap in the overlay.

    Returns:
        ``(superimposed_img, heatmap)``: the uint8 overlay (values clipped at
        255) and the uint8 heatmap (0-255), resized to ``img_show``'s height
        and width.
    """
    if len(data_img.shape) != 4:
        if img_show is None:
            img_show = data_img
        # For a channels-last image model, int_shape is (batch, H, W, C).
        # PIL's Image.resize expects (width, height), so swap the two.
        height, width = K.int_shape(model.input)[1:3]
        data_img = image.img_to_array(
            image.array_to_img(data_img).resize((width, height)))
        data_img = np.expand_dims(data_img, axis=0)
    elif img_show is None:
        # A pre-batched input has no separate display image; use sample 0
        # instead of crashing on img_show.shape below.
        img_show = data_img[0]

    if pred_idx is None:
        preds = model.predict(data_img)
        pred_idx = np.argmax(preds[0])

    target_output = model.output[:, pred_idx]
    layer_output = model.layers[layer_idx].output
    grads = K.gradients(target_output, layer_output)[0]
    # One importance weight per channel: mean gradient over axes 0, 1, 2.
    pooled_grads = K.mean(grads, axis=(0, 1, 2))
    iterate = K.function([model.input], [pooled_grads, layer_output[0]])
    pooled_grads_value, layer_output_value = iterate([data_img])

    # Scale every channel by its pooled gradient (broadcast over the last
    # axis), then average the channels into a 2-D activation map.
    cam = np.mean(layer_output_value * pooled_grads_value, axis=-1)
    cam = np.maximum(cam, 0)
    peak = np.max(cam)
    if peak > 0:
        # Guard: an all-zero map would otherwise become NaN (0 / 0).
        cam = cam / peak
    # cv2.resize takes dsize as (width, height).
    cam = cv2.resize(cam, (img_show.shape[1], img_show.shape[0]))
    cam = np.uint8(255 * cam)

    # OpenCV colour maps are BGR; reverse the channel axis to get RGB.
    colored = cv2.applyColorMap(cam, cv2.COLORMAP_JET)[:, :, ::-1]
    superimposed_img = img_show + colored * alpha
    superimposed_img = np.minimum(superimposed_img, 255).astype('uint8')
    return superimposed_img, cam
def heatmaps(model, data_img, img_show=None):
    """Plot a heatmap overlay and the raw heatmap for every 'conv' layer.

    The model is run once on ``data_img``. The arg-max class is then explained
    for every layer whose name contains ``'conv'``, using ``heatmap``.

    Args:
        model: Keras model; every layer whose name contains ``'conv'`` is
            visualised.
        data_img: A single (unbatched) image array. It is resized to the
            model's input size and given a batch axis before prediction.
        img_show: Image to overlay the heatmaps on; defaults to
            ``np.array(data_img)``.

    Side effects:
        Prints the predicted class with its score, and the number of conv
        layers found. Shows a matplotlib figure with two panels per layer
        (overlay, then raw heatmap), each titled with the layer name.
    """
    if img_show is None:
        img_show = np.array(data_img)
    # For a channels-last image model, int_shape is (batch, H, W, C).
    # PIL's Image.resize expects (width, height), so swap the two.
    height, width = K.int_shape(model.input)[1:3]
    data_img = image.img_to_array(
        image.array_to_img(data_img).resize((width, height)))
    data_img = np.expand_dims(data_img, axis=0)

    # Predict once and explain the same class for every layer.
    preds = model.predict(data_img)
    pred_idx = np.argmax(preds[0])
    # Message: "Predicted: <class>(<score>)"
    print("预测为:%d(%f)" % (pred_idx, preds[0][pred_idx]))

    conv_indices = [idx for idx, layer in enumerate(model.layers)
                    if 'conv' in layer.name]
    # Message: "The model has <n> convolutional layers"
    print('模型共有%d个卷积层' % len(conv_indices))

    # Square grid large enough for two panels per conv layer. Matplotlib
    # requires integer nrows/ncols, so np.ceil's float result must be cast.
    grid = int(np.ceil(np.sqrt(len(conv_indices) * 2)))
    plt.suptitle('heatmaps for each conv')
    for pos, layer_idx in enumerate(conv_indices):
        overlay, raw = heatmap(model, data_img, layer_idx,
                               img_show=img_show, pred_idx=pred_idx)
        title = model.layers[layer_idx].name
        for offset, panel in ((1, overlay), (2, raw)):
            ax = plt.subplot(grid, grid, pos * 2 + offset)
            ax.set_title(title)
            ax.imshow(panel)
            ax.axis('off')
    plt.show()