美文网首页
仿照 keras cifar10.load_data() 加载

仿照 keras cifar10.load_data() 加载

作者: 谢小帅 | 来源:发表于2019-03-07 14:16 被阅读0次

SUN.py

import numpy as np
from glob import glob
import os
import skimage.transform
import skimage.io
import random
import matplotlib.colors

# Target (height, width) in pixels of every preprocessed sample: Resize and the
# RandomCrop in preprocess() produce this size, and load_data() allocates its
# output arrays with it.
image_h = 480
image_w = 640


class Resize(object):
    """Resize image, depth and label of a sample to (image_h, image_w).

    The RGB image is resampled bilinearly; depth and label use nearest-neighbour
    so that no new depth values or class ids are invented. ``preserve_range``
    keeps the original value range (e.g. 0-255 pixels, raw depth units).
    Returns a new dict with the keys 'image', 'depth' and 'label'.
    """

    def __call__(self, sample):
        image, depth, label = sample['image'], sample['depth'], sample['label']
        size = (image_h, image_w)
        # Bi-linear for the RGB image
        image = skimage.transform.resize(image, size, order=1,
                                         mode='reflect', preserve_range=True)
        # Nearest-neighbour for depth/label. anti_aliasing=False is explicit
        # because some scikit-image versions (and some input dtypes) apply a
        # Gaussian pre-filter before downsampling by default, which would blend
        # neighbouring depth values and produce non-integer class ids.
        depth = skimage.transform.resize(depth, size, order=0, mode='reflect',
                                         preserve_range=True, anti_aliasing=False)
        label = skimage.transform.resize(label, size, order=0, mode='reflect',
                                         preserve_range=True, anti_aliasing=False)

        return {'image': image, 'depth': depth, 'label': label}


class RandomScale(object):
    """Randomly rescale image, depth and label by one common factor.

    Args:
        scale: sequence of scale bounds, e.g. (1.0, 1.4); on each call the factor
            is drawn uniformly from [min(scale), max(scale)].
    """

    def __init__(self, scale):  # eg: scale = (1.0, 1.4)
        self.scale_low = min(scale)
        self.scale_high = max(scale)

    def __call__(self, sample):
        image, depth, label = sample['image'], sample['depth'], sample['label']

        target_scale = random.uniform(self.scale_low, self.scale_high)
        # Arrays are indexed (H, W[, C]); all three share one target size.
        target_size = (int(round(target_scale * image.shape[0])),
                       int(round(target_scale * image.shape[1])))
        # Bi-linear for the RGB image
        image = skimage.transform.resize(image, target_size, order=1,
                                         mode='reflect', preserve_range=True)
        # Nearest-neighbour for depth/label; anti_aliasing=False so that a factor
        # below 1 never Gaussian-smooths depth values or class ids (the default
        # pre-filter differs between scikit-image versions).
        depth = skimage.transform.resize(depth, target_size, order=0, mode='reflect',
                                         preserve_range=True, anti_aliasing=False)
        label = skimage.transform.resize(label, target_size, order=0, mode='reflect',
                                         preserve_range=True, anti_aliasing=False)

        return {'image': image, 'depth': depth, 'label': label}


class RandomHSV(object):
    """Randomly jitter hue, saturation and value of ``sample['image']``.

    Args:
        h_range: (low, high) range of the multiplicative hue factor.
        s_range: (low, high) range of the multiplicative saturation factor.
        v_range: additive value bias; the bias is drawn from
            [-min(v_range), max(v_range)], e.g. (25, 25) -> [-25, 25].

    Hue and saturation are clipped to [0, 1]; value is clipped to [0, 255], so
    the image is assumed to hold 0-255 intensities (NOTE(review): true for the
    preserve_range resizes in this module; confirm for other callers).
    Returns a new dict; the input sample is not modified.

    Raises:
        TypeError: if any range is not a list or tuple.
    """

    def __init__(self, h_range, s_range, v_range):
        for rng in (h_range, s_range, v_range):
            if not isinstance(rng, (list, tuple)):
                raise TypeError('h_range, s_range and v_range must be lists or tuples')
        self.h_range = h_range
        self.s_range = s_range
        self.v_range = v_range

    def __call__(self, sample):
        img_hsv = matplotlib.colors.rgb_to_hsv(sample['image'])
        img_h, img_s, img_v = img_hsv[:, :, 0], img_hsv[:, :, 1], img_hsv[:, :, 2]
        h_random = np.random.uniform(min(self.h_range), max(self.h_range))
        s_random = np.random.uniform(min(self.s_range), max(self.s_range))
        v_random = np.random.uniform(-min(self.v_range), max(self.v_range))
        img_h = np.clip(img_h * h_random, 0, 1)
        img_s = np.clip(img_s * s_random, 0, 1)
        img_v = np.clip(img_v + v_random, 0, 255)
        img_new = matplotlib.colors.hsv_to_rgb(np.stack([img_h, img_s, img_v], axis=2))

        # Build a new dict (like the other transforms) instead of mutating the
        # caller's sample in place; any extra keys are carried over unchanged.
        out = dict(sample)
        out['image'] = img_new
        return out


class RandomCrop(object):
    """Randomly crop one aligned (th, tw) window from image, depth and label.

    Args:
        th: crop height in pixels.
        tw: crop width in pixels.

    Raises:
        ValueError: if the sample is smaller than the requested crop.
    """

    def __init__(self, th, tw):
        self.th = th
        self.tw = tw

    def __call__(self, sample):
        image, depth, label = sample['image'], sample['depth'], sample['label']
        h, w = image.shape[0], image.shape[1]
        if h < self.th or w < self.tw:
            raise ValueError('crop size ({}, {}) exceeds sample size ({}, {})'
                             .format(self.th, self.tw, h, w))
        # Top-left corner of the window; randint is inclusive on both ends.
        i = random.randint(0, h - self.th)
        j = random.randint(0, w - self.tw)
        # Use the configured crop size (previously the module globals were
        # used, silently ignoring th/tw).
        rows = slice(i, i + self.th)
        cols = slice(j, j + self.tw)

        return {'image': image[rows, cols, :],
                'depth': depth[rows, cols],
                'label': label[rows, cols]}


class RandomFlip(object):
    """With probability 0.5, mirror image, depth and label left-right together."""

    def __call__(self, sample):
        keys = ('image', 'depth', 'label')
        arrays = [sample[key] for key in keys]
        # One random draw per call decides the flip for all three arrays.
        if random.random() > 0.5:
            arrays = [np.fliplr(arr).copy() for arr in arrays]
        return dict(zip(keys, arrays))


class Normalize(object):
    """Make image and depth channel-first and normalize them via transforms_normalize.

    The image is divided by 255, transposed with axes [2, 0, 1]
    ((H, W, C) -> (C, H, W)) and normalized with the RGB mean/std lists below
    (the widely used ImageNet values). Depth gets a leading axis of size 1 and
    is normalized with mean 19050 / std 9650 (presumably dataset depth
    statistics -- confirm). The sample dict is updated in place and returned.
    """

    def __call__(self, sample):
        rgb, dep = sample['image'], sample['depth']
        rgb_chw = np.transpose(rgb / 255, [2, 0, 1])  # channel first
        # Leading axis so that dep_chw.shape[0] == len(mean) == 1
        dep_chw = np.expand_dims(dep, 0)
        sample['image'] = transforms_normalize(
            rgb_chw, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        sample['depth'] = transforms_normalize(dep_chw, mean=[19050], std=[9650])
        return sample


def transforms_normalize(data, mean, std):  # fake torchvision.transforms.Normalize
    """Normalize each channel as ``torchvision.transforms.Normalize`` does.

    Computes ``(data[c] - mean[c]) / std[c]`` for every index ``c`` of axis 0.
    (The previous version instead rescaled each channel so that it *had* the
    given mean/std -- the inverse operation -- and produced NaN for constant
    channels.)

    Args:
        data: array whose first axis is the channel axis, e.g. (C, H, W).
        mean: per-channel means, one per entry along axis 0.
        std: per-channel standard deviations, same length as ``mean``.

    Returns:
        A new floating-point array with the same shape as ``data``; the input
        array is left unmodified.

    Raises:
        ValueError: if the lengths of mean/std do not match ``data.shape[0]``,
            or if any std is zero.
    """
    data = np.asarray(data)
    mean = np.asarray(mean, dtype=np.float64)
    std = np.asarray(std, dtype=np.float64)
    if not (data.shape[0] == len(mean) == len(std)):
        raise ValueError('expected {} mean/std values, got {} and {}'
                         .format(data.shape[0], len(mean), len(std)))
    if np.any(std == 0):
        raise ValueError('std must be non-zero to avoid division by zero')
    # Reshape stats to (C, 1, 1, ...) so they broadcast over the remaining axes.
    stat_shape = (-1,) + (1,) * (data.ndim - 1)
    return (data - mean.reshape(stat_shape)) / std.reshape(stat_shape)


class ToRGBD(object):
    """Concatenate image and depth along axis 0 into a single 'rgbd' array.

    After Normalize, axis 0 is the channel axis, so this stacks the RGB
    channels followed by the depth channel. The label is passed through as-is.
    """

    def __call__(self, sample):
        channels = (sample['image'], sample['depth'])
        stacked = np.concatenate(channels, axis=0)
        return {'rgbd': stacked, 'label': sample['label']}


def preprocess(sample, train=True):
    """Run the preprocessing pipeline on one sample dict.

    Args:
        sample: dict with 'image', 'depth' and 'label' arrays.
        train: if True (the default, matching the original behaviour) apply the
            random augmentations (scale, HSV jitter, crop, flip); if False only
            resize and normalize, so evaluation data is deterministic.

    Returns:
        dict with 'rgbd' (channels stacked along axis 0; 4 channels for an RGB
        image) and 'label', both spatially (image_h, image_w).
    """
    sample = Resize()(sample)
    if train:
        sample = RandomScale(scale=(1.0, 1.4))(sample)
        sample = RandomHSV((0.9, 1.1), (0.9, 1.1), (25, 25))(sample)
        sample = RandomCrop(image_h, image_w)(sample)
        sample = RandomFlip()(sample)
    sample = Normalize()(sample)
    sample = ToRGBD()(sample)
    return sample


def load_data(phase='train', root=r'E:\Dataset\sun', num_samples=None,
              augment=None, dtype=np.float64):
    """Load one SUN split into memory, in the style of keras ``cifar10.load_data()``.

    Expects ``<root>/<phase>/image/<name>.jpg``, ``<root>/<phase>/depth/<name>.png``
    and ``<root>/<phase>/label_npy/<name>.npy``; the label files drive which
    names are loaded, in sorted order.

    Args:
        phase: 'train', 'val' or 'test'.
        root: dataset root directory.
        num_samples: number of samples to load. Defaults to the per-phase cap
            (train 8000, val 3000, test 2335) and is clamped to the number of
            label files found.
        augment: apply random augmentation; defaults to ``phase == 'train'``.
        dtype: dtype of the returned arrays. The float64 default needs about
            9.8 MB per sample; pass ``np.float32`` to halve memory use.

    Returns:
        (x_data, y_data): arrays of shape (N, 4, image_h, image_w) (RGB-D) and
        (N, image_h, image_w) (labels).

    Raises:
        ValueError: if ``phase`` is not one of the known splits.
    """
    default_counts = {'train': 8000,  # more than this raised a memory error
                      'val': 3000,
                      'test': 2335}
    if phase not in default_counts:
        raise ValueError("phase must be 'train', 'val' or 'test', got %r" % (phase,))
    if augment is None:
        augment = phase == 'train'

    phase_dir = os.path.join(root, phase)
    image_dir = os.path.join(phase_dir, 'image')
    depth_dir = os.path.join(phase_dir, 'depth')
    label_dir = os.path.join(phase_dir, 'label_npy')  # npy

    # Sort for a reproducible sample order; glob's order is filesystem-dependent.
    label_list = sorted(glob(os.path.join(label_dir, '*.npy')))
    count = default_counts[phase] if num_samples is None else num_samples
    count = min(count, len(label_list))
    label_list = label_list[:count]
    # Portable basename without extension (also keeps dots inside names).
    names = [os.path.splitext(os.path.basename(p))[0] for p in label_list]

    x_data = np.empty((count, 4, image_h, image_w), dtype=dtype)  # rgbd
    y_data = np.empty((count, image_h, image_w), dtype=dtype)  # label

    for i, (name, label_path) in enumerate(zip(names, label_list)):
        sample = {
            'image': skimage.io.imread(os.path.join(image_dir, name + '.jpg')),
            'depth': skimage.io.imread(os.path.join(depth_dir, name + '.png')),
            'label': np.load(label_path),
        }
        sample = preprocess(sample, train=augment)
        x_data[i] = sample['rgbd']
        # Labels go to y_data (previously written to x_data, leaving y_data
        # as uninitialized memory).
        y_data[i] = sample['label']
    return x_data, y_data


if __name__ == '__main__':
    # Smoke run: load the training split and report the array shapes.
    x_train, y_train = load_data(phase='train')
    for array in (x_train, y_train):
        print(array.shape)

相关文章

网友评论

      本文标题:仿照 keras cifar10.load_data() 加载

      本文链接:https://www.haomeiwen.com/subject/ueycpqtx.html