# 针对模型生成图片存在锯齿和羽化问题的处理方案
# (Fixing jagged edges and feathering in the masks produced by a segmentation model.)
import paddlex as pdx
import cv2
import numpy as np
import matplotlib.pyplot as plt
# Model to use. A pre-trained PaddleX human-segmentation model can be downloaded from
# https://bj.bcebos.com/paddlex/examples/human_seg/models/humanseg_server_params.tar
# and, once extracted, loaded the same way (e.g. from
# 'C:/Users/bigdata/Downloads/humanseg_server_params').
image_name = 'd:/tmp/002.jpg'
model = pdx.load_model('d:/tmp/models/Xception65/best_model/')
result = model.predict(image_name)
# Write PaddleX's own visualization of the prediction into the 'output' directory.
# NOTE(review): weight=0.0 presumably means the original image is not blended
# into the rendering -- confirm against the PaddleX `seg.visualize` docs.
pdx.seg.visualize(
    image_name,
    result,
    weight=0.0,
    save_dir='output',
)
# Naive alpha map, without the edge clean-up done further below:
#   (result['score_map'][:, :, 1] * 255).astype('uint8')
def process(logit, shape, thresh=120):
    """Turn a per-pixel score map into a cleaned-up 8-bit alpha mask.

    The scores are scaled to 0-255 and resized to the target image size.
    Everything at or below ``thresh`` is then cut to 0 (fully transparent),
    which removes the faint feathered fringe around the subject. The remaining
    range ``(thresh, 255]`` is stretched linearly back to ``(0, 255]``, so the
    subject's edge keeps a short gradient instead of a hard, jagged step.

    Args:
        logit: 2-D array of per-pixel scores. Assumed to be foreground
            probabilities in [0, 1] -- TODO confirm against the model output.
        shape: Target size as ``(width, height)``, the ``dsize`` order that
            ``cv2.resize`` expects.
        thresh: Cut-off on the 0-255 scale; must satisfy ``0 <= thresh < 255``.
            Defaults to 120, the value this function previously hard-coded.

    Returns:
        ``uint8`` array of shape ``(height, width)``, usable as an alpha channel.

    Raises:
        ValueError: if ``thresh`` is outside ``[0, 255)`` (255 would divide by
            zero below).
    """
    if not 0 <= thresh < 255:
        raise ValueError(f"thresh must be in [0, 255), got {thresh!r}")
    logit = cv2.resize(logit * 255, shape)
    # Shift so the threshold becomes 0 and drop everything below it. Also cap
    # at the top of the remaining range, so the uint8 cast below can never
    # wrap around (e.g. if a score slightly exceeds 1.0).
    logit = np.clip(logit - thresh, 0, 255 - thresh)
    # Stretch [0, 255 - thresh] back to the full [0, 255] range.
    logit = 255 * logit / (255 - thresh)
    return logit.astype('uint8')
# Combine the original photo with the cleaned-up mask as its alpha channel.
im = cv2.imread(image_name)
if im is None:
    # cv2.imread signals a missing or unreadable file by returning None
    # instead of raising; fail here with a clear message.
    raise FileNotFoundError(f"could not read image: {image_name}")
rows, cols, _ = im.shape
# Channel 1 of the score map -- presumably the foreground class, TODO confirm.
# cv2.resize takes the target size as (width, height), i.e. (cols, rows).
score_map = process(result['score_map'][:, :, 1], (cols, rows))
score_map = np.expand_dims(score_map, axis=2)
# OpenCV loads images in BGR order, but Matplotlib reads a 4-channel array as
# RGBA. Convert first so red and blue are not swapped on screen.
rgba = np.concatenate((cv2.cvtColor(im, cv2.COLOR_BGR2RGB), score_map), axis=2)  # 3 channels -> 4
plt.imshow(rgba)
plt.show()
# To save with transparency, convert back to OpenCV's BGRA order first:
# cv2.imwrite('d:/output/tmp.png', cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA))
# 网友评论 (the "reader comments" heading from the source web page; not part of the script)