仿照 keras cifar10.load_data() 加载

仿照 keras cifar10.load_data() 加载

作者: 谢小帅 | 来源:发表于2019-03-07 14:16 被阅读0次


    import numpy as np
    from glob import glob
    import os
    import skimage.transform
    import skimage.io
    import random
    import matplotlib.colors
    # Target spatial size (rows, columns) that every sample ends up with.
    image_h, image_w = 480, 640
    class Resize(object):
        """Resize image, depth and label of a sample to (image_h, image_w).

        The colour image is resampled bi-linearly (order=1); depth and label use
        nearest-neighbour sampling (order=0) so no new depth values or class ids
        are invented.  ``preserve_range=True`` keeps the original value range
        instead of rescaling to [0, 1].  Returns a new dict with the keys
        'image', 'depth' and 'label'.
        """

        def __call__(self, sample):
            image, depth, label = sample['image'], sample['depth'], sample['label']
            size = (image_h, image_w)  # e.g. 480, 640
            # Bi-linear for the colour image.
            image = skimage.transform.resize(image, size, order=1,
                                             mode='reflect', preserve_range=True)
            # Nearest-neighbour for depth/label.  anti_aliasing is disabled
            # explicitly: depending on the scikit-image version and input dtype,
            # resize() may Gaussian-smooth before downsampling, which would blend
            # neighbouring depth values / class ids and defeat order=0.
            depth = skimage.transform.resize(depth, size, order=0, mode='reflect',
                                             preserve_range=True, anti_aliasing=False)
            label = skimage.transform.resize(label, size, order=0, mode='reflect',
                                             preserve_range=True, anti_aliasing=False)
            return {'image': image, 'depth': depth, 'label': label}
    class RandomScale(object):
        """Rescale image, depth and label by one random common factor.

        Args:
            scale: collection of factors, e.g. ``(1.0, 1.4)``; on every call a
                factor is drawn uniformly between ``min(scale)`` and ``max(scale)``.
        """

        def __init__(self, scale):  # eg: scale = (1.0, 1.4)
            self.scale_low = min(scale)
            self.scale_high = max(scale)

        def __call__(self, sample):
            image, depth, label = sample['image'], sample['depth'], sample['label']
            target_scale = random.uniform(self.scale_low, self.scale_high)
            # Both spatial axes (shape[0], shape[1]) get the same factor.
            target_size = (int(round(target_scale * image.shape[0])),
                           int(round(target_scale * image.shape[1])))
            # Bi-linear for the colour image.
            image = skimage.transform.resize(image, target_size, order=1,
                                             mode='reflect', preserve_range=True)
            # Nearest-neighbour for depth/label, with anti-aliasing disabled so
            # that a factor < 1 can never smooth/blend class ids or depth values.
            depth = skimage.transform.resize(depth, target_size, order=0,
                                             mode='reflect', preserve_range=True,
                                             anti_aliasing=False)
            label = skimage.transform.resize(label, target_size, order=0,
                                             mode='reflect', preserve_range=True,
                                             anti_aliasing=False)
            return {'image': image, 'depth': depth, 'label': label}
    class RandomHSV(object):
        """Randomly jitter hue, saturation and value of ``sample['image']``.

        Hue and saturation are multiplied by factors drawn uniformly from
        ``[min(h_range), max(h_range)]`` / ``[min(s_range), max(s_range)]`` and
        clipped to [0, 1].  Value gets an additive offset drawn uniformly from
        ``[-min(v_range), max(v_range)]`` and is clipped to [0, 255].
        The input dict is updated in place and returned.
        """

        def __init__(self, h_range, s_range, v_range):
            assert all(isinstance(r, (list, tuple))
                       for r in (h_range, s_range, v_range))
            self.h_range, self.s_range, self.v_range = h_range, s_range, v_range

        def __call__(self, sample):
            hsv = matplotlib.colors.rgb_to_hsv(sample['image'])
            # The three random draws happen in the order hue, saturation, value.
            hue_gain = np.random.uniform(min(self.h_range), max(self.h_range))
            sat_gain = np.random.uniform(min(self.s_range), max(self.s_range))
            val_shift = np.random.uniform(-min(self.v_range), max(self.v_range))
            hue = np.clip(hsv[:, :, 0] * hue_gain, 0, 1)
            sat = np.clip(hsv[:, :, 1] * sat_gain, 0, 1)
            val = np.clip(hsv[:, :, 2] + val_shift, 0, 255)
            jittered = np.stack([hue, sat, val], axis=2)
            sample['image'] = matplotlib.colors.hsv_to_rgb(jittered)
            return sample
    class RandomCrop(object):
        """Crop image, depth and label at one shared random location.

        Args:
            th: crop height in pixels.
            tw: crop width in pixels.

        Raises:
            ValueError: (on call) if the input is smaller than the crop size.
        """

        def __init__(self, th, tw):
            self.th = th
            self.tw = tw

        def __call__(self, sample):
            image, depth, label = sample['image'], sample['depth'], sample['label']
            h, w = image.shape[0], image.shape[1]
            if h < self.th or w < self.tw:
                raise ValueError(
                    'crop size ({}, {}) is larger than input size ({}, {})'.format(
                        self.th, self.tw, h, w))
            # randint is inclusive at both ends, so every valid offset can occur.
            i = random.randint(0, h - self.th)
            j = random.randint(0, w - self.tw)
            # Slice with the configured crop size (self.th / self.tw) so the
            # constructor arguments are honoured, not a module-level image size.
            rows = slice(i, i + self.th)
            cols = slice(j, j + self.tw)
            return {'image': image[rows, cols, :],
                    'depth': depth[rows, cols],
                    'label': label[rows, cols]}
    class RandomFlip(object):
        """Mirror image, depth and label left-right together, with probability 0.5.

        When flipped, contiguous copies are returned; otherwise the original
        arrays are handed back as-is in a new dict.
        """

        def __call__(self, sample):
            arrays = [sample['image'], sample['depth'], sample['label']]
            if random.random() > 0.5:  # 0.5 chance flip
                arrays = [np.fliplr(a).copy() for a in arrays]
            return dict(zip(('image', 'depth', 'label'), arrays))
    class Normalize(object):
        """Scale RGB by 1/255, move channels first, and normalise image and depth.

        ``sample['image']`` (assumed (H, W, C)) is divided by 255 and transposed
        to channel-first; ``sample['depth']`` gains a leading axis so its
        shape[0] is the channel count.  Both are then passed through
        ``transforms_normalize`` -- the image with mean (0.485, 0.456, 0.406) /
        std (0.229, 0.224, 0.225) (the usual ImageNet values), the depth with
        mean 19050 / std 9650.  The input dict is updated in place and returned.
        """

        def __call__(self, sample):
            rgb = np.transpose(sample['image'] / 255, [2, 0, 1])  # channel first
            dep = np.expand_dims(sample['depth'], 0)  # leading channel axis
            sample['image'] = transforms_normalize(
                rgb, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            sample['depth'] = transforms_normalize(dep, mean=[19050], std=[9650])
            return sample
    def transforms_normalize(data, mean, std):
        """Channel-wise normalisation mimicking ``torchvision.transforms.Normalize``.

        Computes ``(data[c] - mean[c]) / std[c]`` for every channel ``c`` of a
        channel-first array, which is what torchvision's Normalize does.  The
        input is not modified; a new float64 array is returned (so integer
        inputs such as uint16 depth maps are handled without truncation).

        Args:
            data: array whose first axis is the channel axis, e.g. (C, H, W).
            mean: per-channel means, one value per channel.
            std: per-channel standard deviations, one non-zero value per channel.

        Returns:
            A float64 array with the same shape as ``data``.

        Raises:
            ValueError: if ``data`` has no channel axis, if ``len(mean)`` /
                ``len(std)`` differ from the channel count, or if any std is 0.
        """
        data = np.asarray(data, dtype=np.float64)
        if data.ndim == 0 or not (data.shape[0] == len(mean) == len(std)):
            raise ValueError(
                'data has shape {} but len(mean)={} and len(std)={}; the first '
                'axis must be the channel axis with one mean/std per '
                'channel'.format(data.shape, len(mean), len(std)))
        mean = np.asarray(mean, dtype=np.float64)
        std = np.asarray(std, dtype=np.float64)
        if np.any(std == 0):
            raise ValueError('std must be non-zero for every channel')
        # Shape (C, 1, 1, ...) so the statistics broadcast over all other axes.
        stat_shape = (-1,) + (1,) * (data.ndim - 1)
        return (data - mean.reshape(stat_shape)) / std.reshape(stat_shape)
    class ToRGBD(object):
        """Merge channel-first image and depth into one ``'rgbd'`` array.

        Image and depth are concatenated along axis 0; the label is passed
        through unchanged.  Only the keys 'rgbd' and 'label' are returned.
        """

        def __call__(self, sample):
            stacked = np.concatenate([sample['image'], sample['depth']])
            return dict(rgbd=stacked, label=sample['label'])
    def preprocess(sample):
        """Run one sample dict through the full augmentation/normalisation chain.

        Order: resize -> random scale (1.0-1.4x) -> random HSV jitter ->
        random crop to (image_h, image_w) -> random horizontal flip ->
        normalise -> merge image and depth into 'rgbd'.
        """
        pipeline = [
            Resize(),
            RandomScale(scale=(1.0, 1.4)),
            RandomHSV((0.9, 1.1), (0.9, 1.1), (25, 25)),
            RandomCrop(image_h, image_w),
            RandomFlip(),
            Normalize(),
            ToRGBD(),
        ]
        for transform in pipeline:
            sample = transform(sample)
        return sample
    def load_data(phase='train', root=r'E:\Dataset\sun', num_samples=None,
                  dtype=np.float64):
        """Load one dataset split into memory, in the style of keras cifar10.load_data().

        Expects ``<root>/<phase>/image/<name>.jpg``, ``<root>/<phase>/depth/<name>.png``
        and ``<root>/<phase>/label_npy/<name>.npy`` sharing base names.  Every
        sample is passed through ``preprocess`` -- note that this applies the
        random augmentations for 'val' and 'test' too.

        Args:
            phase: 'train', 'val' or 'test'.
            root: dataset root directory.
            num_samples: maximum number of samples to load.  Defaults to a
                per-phase cap (train 8000, val 3000, test 2335) and is further
                limited to the number of label files actually found.
            dtype: dtype of both returned arrays.  The float64 default needs
                ``N * 4 * image_h * image_w * 8`` bytes for x alone; float32
                halves the memory footprint.

        Returns:
            (x_data, y_data) with shapes (N, 4, image_h, image_w) and
            (N, image_h, image_w).

        Raises:
            ValueError: if ``phase`` is not 'train', 'val' or 'test'.
        """
        default_caps = {'train': 8000,  # larger values ran into memory errors
                        'val': 3000,
                        'test': 2335}
        if phase not in default_caps:
            raise ValueError(
                "phase must be 'train', 'val' or 'test', got {!r}".format(phase))
        image_dir = os.path.join(root, phase, 'image')
        depth_dir = os.path.join(root, phase, 'depth')
        label_dir = os.path.join(root, phase, 'label_npy')  # npy
        # Sorted so the subset selected by the cap is deterministic across runs.
        label_list = sorted(glob(os.path.join(label_dir, '*.npy')))
        # basename/splitext is portable, unlike splitting on the Windows '\\'.
        image_name = [os.path.splitext(os.path.basename(x))[0] for x in label_list]
        image_list = [os.path.join(image_dir, x + '.jpg') for x in image_name]
        depth_list = [os.path.join(depth_dir, x + '.png') for x in image_name]
        if num_samples is None:
            num_samples = default_caps[phase]
        # Never index past the files that actually exist.
        num_samples = min(num_samples, len(label_list))
        x_data = np.empty((num_samples, 4, image_h, image_w), dtype=dtype)  # rgbd
        y_data = np.empty((num_samples, image_h, image_w), dtype=dtype)  # label
        paths = list(zip(image_list, depth_list, label_list))[:num_samples]
        for i, (image_path, depth_path, label_path) in enumerate(paths):
            sample = {
                'image': skimage.io.imread(image_path),
                'depth': skimage.io.imread(depth_path),
                'label': np.load(label_path),
            }
            sample = preprocess(sample)
            x_data[i] = sample['rgbd']
            y_data[i] = sample['label']
        return x_data, y_data
    if __name__ == '__main__':
        # Script entry point: load the training split into memory.
        x_train, y_train = load_data('train')



          本文标题:仿照 keras cifar10.load_data() 加载
