美文网首页
2019-09-06 pytorch实现视频数据加载

2019-09-06 pytorch实现视频数据加载

作者: 强扭的解渴瓜 | 来源:发表于2019-09-28 12:46 被阅读0次

小白的自我救赎
书接上回,在action recognition中,我们已经学会了读取视频数据中的帧图像,并将读取到的帧图像保存在文件夹frames_of_video中。今天将学习如何从这些帧图像中随机选取一个起始帧,加载从该帧开始的连续16帧图像并保存为张量形式,然后进行随机裁剪等操作,从而在pytorch中实现帧图像的加载,以便完成动作识别。
参考文档:
http://pytorch123.com/ThirdSection/DataLoding/
https://www.jianshu.com/p/4ebf2a82017b
这段代码的总体思想是:索引到存储帧图像的文件夹,按照train_list逐个找到每个视频对应的帧图像,随机选取起始帧,然后读取连续16帧。这些帧先存储为形状为[z,h,w,c]的numpy数组,经ToTensor变换后成为格式为[z,c,h,w]的4维张量(经DataLoader组成批次后为[batch,z,c,h,w])
z:帧数=16
c:图像通道数(RGB图像为3)
h:图像高度
w:图像宽度
代码如下:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Aug 26 21:35:32 2019

@author: xuguangying
"""

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import pandas as pd
import os
import random

# Ignore warnings
import warnings
# Silence every Python warning for the whole process (warnings raised by
# numpy, skimage, torch, etc. are hidden too), not just in this module.
warnings.filterwarnings("ignore")
# NOTE(review): looks like a leftover debug print; it runs on every import.
print('a')
#%%
class UCF101(Dataset):
    """UCF101 clip dataset read from pre-extracted video frames.

    Each line of the info list is ``"<video path> <label>"`` separated by a
    single space, with no header. The frames of a video are read from
    ``root_dir/<video path without extension>/image_XXXXX.jpg`` where
    ``XXXXX`` is a 1-based, zero-padded (5 digits) frame index.
    """

    def __init__(self, info_list, root_dir, transform=None, clip_len=16):
        """
        Args:
            info_list (string): Path to the info list file with annotations.
            root_dir (string): Directory with all the video frames.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            clip_len (int, optional): Number of consecutive frames loaded
                per clip. Defaults to 16.
        """
        self.landmarks_frame = pd.read_csv(info_list, delimiter=' ', header=None)
        self.root_dir = root_dir
        self.transform = transform
        self.clip_len = clip_len

    def __len__(self):
        """Return the number of videos listed in the info list."""
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        """Return ``{'video_x': clip, 'video_label': label}`` for row ``idx``.

        Without a transform, ``clip`` is a float64 array of shape
        ``(clip_len, H, W, 3)``.
        """
        video_path = self.landmarks_frame.iloc[idx, 0]
        video_label = self.landmarks_frame.iloc[idx, 1]

        video_x = self.get_single_video_x(video_path)
        sample = {'video_x': video_x, 'video_label': video_label}

        if self.transform:
            sample = self.transform(sample)
        return sample

    def get_single_video_x(self, aaa):
        """Load ``clip_len`` consecutive frames starting at a random frame.

        Args:
            aaa (string): Video path from the info list; its extension is
                dropped to locate the frame folder under ``root_dir``.

        Returns:
            numpy.ndarray: float64 array of shape ``(clip_len, H, W, 3)``,
            where H and W come from the first frame read.

        Raises:
            ValueError: If the folder holds fewer than ``clip_len`` frames.
        """
        pic_path = self._frame_dir(aaa)
        image_start = self._random_start(self._count_frames(pic_path))

        video_x = None
        for i in range(self.clip_len):
            image_path = os.path.join(pic_path, self._frame_name(image_start + i))
            tmp_image = io.imread(image_path)
            if video_x is None:
                # Size the buffer from the first frame instead of assuming
                # 240x320; the trailing 3 keeps single-channel frames
                # broadcast to three channels as before.
                video_x = np.zeros((self.clip_len,) + tmp_image.shape[:2] + (3,))
            video_x[i] = tmp_image
        return video_x

    def _frame_dir(self, video_path):
        """Return the folder holding the extracted frames of ``video_path``."""
        name, _ = os.path.splitext(video_path)
        # Resolve against this dataset's root_dir rather than a global, so the
        # class also works when imported from another module.
        return os.path.join(self.root_dir, name)

    @staticmethod
    def _count_frames(pic_path):
        """Count the ``image_*.jpg`` frame files in ``pic_path``.

        Other files in the folder are ignored so they cannot push the random
        start past the last real frame.
        """
        return sum(
            1 for f in os.listdir(pic_path)
            if f.startswith('image_') and f.endswith('.jpg')
        )

    @staticmethod
    def _frame_name(image_id):
        """Return the file name of the 1-based frame ``image_id``."""
        return 'image_%05d.jpg' % image_id

    def _random_start(self, num):
        """Pick a random 1-based start frame so the whole clip fits in ``num`` frames."""
        if num < self.clip_len:
            raise ValueError(
                'video has %d frames, need at least %d' % (num, self.clip_len))
        return random.randint(1, num - self.clip_len + 1)
#%%
class ClipSubstractMean(object):
    """Subtract a fixed per-channel mean from every frame of a clip.

    ``self.means`` is stored in (r, g, b) order and broadcasts over the last
    axis of ``video_x``; this assumes the clip's channels are RGB-ordered.
    """

    def __init__(self, b=104, g=117, r=123):
        # Reverse the (b, g, r) arguments into RGB order.
        self.means = np.array([r, g, b])

    def __call__(self, sample):
        clip = sample['video_x']
        label = sample['video_label']
        return {'video_x': np.subtract(clip, self.means), 'video_label': label}
#%%
class Rescale(object):
    """Resize every frame of a clip to a target size.

    Args:
        output_size (tuple or int): Desired output size. A tuple is taken as
            ``(new_h, new_w)``. An int is the length of the shorter edge; the
            longer edge is scaled by the same factor to keep the aspect ratio.
    """

    def __init__(self, output_size=(182, 242)):
        # Only an int or an (h, w) tuple is accepted.
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def _target_size(self, h, w):
        """Return the integer ``(new_h, new_w)`` for a frame of size ``h`` x ``w``."""
        if isinstance(self.output_size, int):
            # Scale both edges proportionally so the shorter edge becomes
            # ``output_size``.
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size
        return int(new_h), int(new_w)

    def __call__(self, sample):
        video_x, video_label = sample['video_x'], sample['video_label']

        # Height and width are axes 1 and 2 of the (frames, H, W, C) array;
        # ``video_x[2]`` would be the third frame, not the width.
        h, w = video_x.shape[1], video_x.shape[2]
        new_h, new_w = self._target_size(h, w)

        # Keep the clip's own frame count and channel axes instead of
        # assuming 16 frames x 3 channels.
        new_video_x = np.zeros((video_x.shape[0], new_h, new_w) + video_x.shape[3:])
        for i in range(video_x.shape[0]):
            new_video_x[i] = transform.resize(video_x[i], (new_h, new_w))

        return {'video_x': new_video_x, 'video_label': video_label}
#%%
class RandomCrop(object):
    """Crop every frame of a clip at the same random location.

    Args:
        output_size (tuple or int): Desired output size ``(new_h, new_w)``.
            If int, a square crop of that size is made.
    """

    def __init__(self, output_size=(160, 160)):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        video_x, video_label = sample['video_x'], sample['video_label']

        h, w = video_x.shape[1], video_x.shape[2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError(
                'crop size (%d, %d) is larger than frame size (%d, %d)'
                % (new_h, new_w, h, w))

        # randint's upper bound is exclusive: +1 makes the bottom/right-most
        # crop reachable and allows a crop equal to the frame size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        # One offset for all frames keeps the clip spatially consistent.
        # astype() returns a float64 copy, so the output never aliases the input.
        new_video_x = video_x[:, top:top + new_h, left:left + new_w].astype(np.float64)
        return {'video_x': new_video_x, 'video_label': video_label}
#%%
class ToTensor(object):
    """Convert the numpy clip and its label in a sample to torch tensors.

    The clip is reordered from (frames, H, W, C) to (frames, C, H, W), i.e.
    each frame moves to torch's channels-first image layout; its dtype is
    kept. The label becomes a one-element float tensor.
    """

    def __call__(self, sample):
        clip = sample['video_x']
        label = sample['video_label']

        # (T, H, W, C) -> (T, C, H, W); np.array() makes an owned copy of the
        # transposed view before it is wrapped by torch.
        channels_first = np.array(clip.transpose((0, 3, 1, 2)))

        return {
            'video_x': torch.from_numpy(channels_first),
            'video_label': torch.FloatTensor([label]),
        }
#%%
if __name__ == '__main__':
    # Kept as module-level names: some versions of the frame loader look up
    # ``root_list`` as a global instead of using the dataset's root_dir.
    root_list = '/media/xuguangying/action recogniton/database/UTH/frames_of_video'
    info_list = '/media/xuguangying/action recogniton/database/UTH/ucfTrainTestlist/trainlist01.txt'

    # Mean subtraction -> resize to 182x242 -> random 160x160 crop -> tensors.
    clip_pipeline = transforms.Compose(
        [ClipSubstractMean(), Rescale(), RandomCrop(), ToTensor()])
    myUCF101 = UCF101(info_list, root_list, transform=clip_pipeline)
    dataloader = DataLoader(myUCF101, batch_size=8, shuffle=True, num_workers=6)

    for i_batch, sample_batched in enumerate(dataloader):
        clip_size = sample_batched['video_x'].size()
        labels = sample_batched['video_label']
        # Batch index, clip tensor size and label tensor size.
        print(i_batch, clip_size, labels.size())
        # Batch index, clip tensor size and the labels of this batch.
        print(i_batch, clip_size, labels)

相关文章

网友评论

      本文标题:2019-09-06 pytorch实现视频数据加载

      本文链接:https://www.haomeiwen.com/subject/rttsyctx.html