suyc's blog

哎,什么时候才能把英语学好啊🙇‍~

PyTorch读取数据入门

本文主要涉及内容:DatasetDataLoaderDatasetFloder等相关源码分析。
看了网上众多的关于这方面的文章,详细解释道理的很少,看完之后我依然很迷惑。所以我在看过看过网上一些教程和PyTorch的部分源码以后写了一些总结,基本上也是我自己在学习这部分内容时的一个经过。文章前边可能有些地方不可避免地走了弯路,但是我也希望如果你是初学者,建议看一看,思考一下,和最后的方法作对比。
本文基于PyTorch 1.1.0。

基础知识

忽略基础知识直接进入看源码

可能需要导入的包

1
2
3
4
5
from torch.utils.data.dataset import Dataset
import torch.utils.data as Data
from torchvision import transforms
import pandas as pd
from PIL import Image

Dataset基础类

首先是Dataset基础类,所有的要传入DataLoder的类都要继承这个类才行,同时必须重载__getitem____len__这两个方法。

1
2
3
4
5
6
7
8
9
class Dataset(object):
def __getitem__(self, index):
raise NotImplementedError

def __len__(self):
raise NotImplementedError

def __add__(self, other):
return ConcatDataset([self, other])

可以看到,Dataset提供了一个__add__方法,用于两个Dataset的相加,返回一个ConcatDataset对象。初始化时简单的以list的形式,把传入的参数合并在一起。

Transforms

定义Transforms可以对读入的数据进行一些变换操作,可以放下如下位置

1
2
3
4
5
6
7
8
9
10
11
class MyDataset(Dataset):
def __init__(self, ..., transforms=None):
...
self.transforms = transforms

def __getitem__(self, index):
...
data = ...
if self.transforms is not None:
data = self.transforms(data)
return data

这里的self.transforms有两层含义,一方面是可以传入一个transform,另外一方面,也可以在初始化方法中自己定义一组transform,使用链式定义的也可以,使用transforms.Compose()定义也可以。

结合Pandas读取图片

可以在初始化的过程中读入csv文件,csv文件中存好相关的配置项,然后在getitem的时候在依照配置项读取图片。这样的好处是,配合DataLoader可以边读边训练,不用先花费大量时间把图片读进来在训练。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class MyDataset(Dataset):
def __init__(self, csv_path):
# 读取 csv 文件
self.data_info = pd.read_csv(csv_path, header=None)
# 下边可以对各个列进行解析
...

def __getitem__(self, index):
# 读取图像
img_as_img = Image.open(self.image[index])

# 在csv中可以配置,是否需要额外操作
if self.operation_arr[index]:
...

# 可以把图像转换成tensor,然后从csv中读取出label信息,一起返回
return (img_as_tensor, image_label)

[1]中还给出了怎么从csv中读取像素值,然后转换成图片在返回,和上述过程大同小异。

但是,存在的问题是,DataLoader的shuffle还能不能用了,这个问题对训练来说非常重要。答案是可以!详情请继续读。

Dataloader

Dataset使得An image is read from the file on the fly,但是缺少了三个功能:batching、shuffling、multiprocessing。linkDataloader则提供了这些功能。
配合Dataloader读取数据,for循环外可加一层循环控制epoch。

1
2
3
4
5
6
7
if __name__ == "__main__":
custom_dataset = MyDataset(...)
train_loader = Data.DataLoader(dataset=custom_dataset,
batch_size=BATCH_SIZE,
shuffle=False)
for step, (image, label) in enumerate(train_loader):
...

借鉴

这里分析两个前辈给出的解决方案,看一下他们的方案有什么优缺点。

[3]中使用的方法,每次读取一个数据,基本可看作模仿上边过程的应用,问题是无法shuffle。

[2]中提及了一个比较完美的局部shuffle的解法,具体代码详见链接,这里对他的思想做个说明。
首先,他就提到了任务和要求。任务就是数据量大,无法一次读到内存中。其次是要求,首先是每次读取一部分数据,然后是重点,能够shuffle。能否shuffle对于训练很重要。在这个任务中,把数据存在了csv中,但是从上边看,其实存储在csv中图片的路径等形式依然可以采取这种方式。

  • __init__方法,__init__(self,file_path,nraws,shuffle=False),这个方法的nraws参数是用来定义每次读取多少行进行shuffle,此外,还在这个方法中计算了这个csv一共多少行。这个写法同样适用于读取图片地址。
  • def initial(self)在这个方法中,是现实读入nraws行,然后对着nraws进行shuffle,这个nraws应该对于batch来说稍微大一些,不然这个局部shuffle的意义就没了。比如作者在这里就使用的nraws=1000batch_size=64。为什么不写入init中,因为这个方法每个epoch之前都要用。
  • __len__不用多说,直接读取计算出来的行数。
  • __getitem__的重载即是每次都队列的前端拿取一个元素,如果队列为空了,则重新读取一个nraws个元素。这个方法和initial共同维持了一个队列,数据结构的美妙之处。

改进

读取图片的速度和totensor的转换速度共同影响了dataloader的时间,而且totensor占据了主要部分,可以在预处理的时候把所有的数据以pkl(或bin)的形式存储好,这样读取的时候直接读取的就是tensor的数据,这样可以显著加快速度。

对于[2]提出的方法,可以直接把一块数据存储成一个pkl,如1000个图片存成一个pkl,不用重新totensor,这样的整体速度会相对快,可以对这一组数据进行shuffle。这种改变不影响代码的整体结构,整体思路不变。

再改进

一代目步骤,基于上述两个方法的一个小方法(下文还会对这个步骤继续改进):

  1. 预处理的图片totensor后全部转化成pkl,不过这次是一对一的转换,这个转化可能需要很多时间,但是这却会给以后的训练节约很多时间。
  2. 把转化后的pkl地址和label对应着存到csv中,我的路径方式是./lable/xxx.pkl,在解析的时候只需要把label和地址拼接起来即可组成正确的相对路径。此外,在这里还可以配置是否需要其他操作。用csv的形式变化更加多样,可以实现更复杂的功能。
  3. 定义一个函数,这个函数的功能是,读取这个csv文件,然后得到lenth,对这个lenth进行shuffle。
  4. 得到一个shuffle之后的序列,使用这个序列传入MyDataset,按照shuffle之后的顺序读取,再配合DataLoader即可每次都拿到shuffle之后的数据。等到这个epoch结束后,再shuffle,即可进行下次训练。

还可以先进行数据的划分,如使用scikit-learn的分层划分,然后再对train set进行shuffle。这样需要对shuffle进行改造,即不对validation set进行shuffle,这样可以实现了训练集与测试集的划分。

看源码

通过看源码来获取一些新的视角,来帮助思考,DataLoader源码以及它调用的Sampler的源码

DataLoader

关于DataLoader的定义,直击要害,直接看两个最关键的点。第一段代码:

1
2
3
4
5
6
7
if batch_sampler is None:
if sampler is None:
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size, drop_last)

这段代码解释了为啥dataset需要重载__len__,实际上是传递给了各种Sampler做下一步的处理。在DataLoader的定义中没有使用dataset的长度信息

然后关注DataLoader中的第二段代码:

1
2
def __iter__(self):
return _DataLoaderIter(self)

这个可以使得DataLoader进行迭代,即可以使用for循环。廖雪峰老师的教程里有写到,这个方法如果返回xxx,则会不断的调用xxx.__next__,廖雪峰老师说通常这个就会返回self,返回对自己的迭代,这里则直接祭出了高级用法,返回一个对象了,使用这个对象的__next__

_DataLoaderIter

可以看到,这个方法实际是我们在使用for循环的时候扮演者重要角色的大BOSS。这个方法的注释中写了大量的内容,包括数据流、多线程遇到的问题,这部分不多赘述。先看一下单线程的方法,即num_workers=0的情况(TODO:补充多线程):

1
2
3
4
5
6
def __next__(self):
if self.num_workers == 0: # same-process loading
# 前提 self.sample_iter = iter(self.batch_sampler)
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
return batch

可以看到,它又通过batch_sampler的迭代获取序号的列表生成器,然后根据序号集合indices进行读取数据,再组合成一个batch。StopIteration这个用法真是绝了,廖雪峰老师也提到过,用for循环调用generator时,是拿到不到返回的值的,所以当报错的时候即为for循环结束的地方,也就是完成了一次epoch的训练。
collate_fn是一个callable的函数,如果不传入自定义的函数,则调用默认的。如果你的合并操作需要一些特殊的操作,可以自己定义这个函数,那么可以参考官方文档,那里给出了一个例子。如果使用默认的,则直接贴出源码链接。可以看到这个默认的函数非常之强大,精通各种合并。总之,合并后的数据集是一个batch。
这里是唯一一次用到__getitem__方法。下文会有分析。

BatchSampler

BatchSampler是对其他的Sampler进行了封装,代码非常简单。
有意思的的关于Batchs长度的计算,用全部数据除以batch size(单纯的觉得这里很有意思)。

1
2
3
4
5
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []

这段代码表示,每次从sampler中获取的都是一个序号,把这个序号组合到batch size的大小,然后yield。关于yield在这里的用法,新更新了文章,进行了详细解释link

关于使用到的这两个Sampler。SquentialSampler就不多说了,顺序直接迭代,返回一二三四五六七。RandomSampler则是对序号打乱顺序,这里多说一点,按照这个程序中的调用,返回的是torch.randperm(n),即n个数,没有重复的。

这里的n,就是n = len(self.data_source)也就是我们一开始重载的__len__,直到这里才起作用。

此外,这个Random还支持一个叫做replacement的参数,即可以有重复的取样。这个函数还有更厉害的组合拳,可以去看源码。

结论

思考

至此,大概的脉络已经分析出来了。如何调用大数据的时候写好我们的Dataset类呢?我们可以看到两个重载的方法出现的位置,答案似乎已经清楚了,就是结合Pandas读取图片一节描述的内容。只要我们能够清楚地知道csv文件的长度,那么Dataloader就会帮我们进行shuffle,而不用我们自己进行任何操作。在读取图片的时候,也就是上文唯一提到的__getitem__出现的位置,我们只需要按照index的位置获取csv中的地址,读入图片即可,如果已经转换成了pkl,那么会大大加快我们的速度。

二代目步骤:
维持一代目的1、2步骤不变,3、4的步骤可以交给DataLoader来完成了~也就是说我们只需要定义好myDataSet中的__getitem__,用好index这个关键变量,那么剩下的工作就交给DataLoader的shuffle就行了。

那么现在再来看这三个引用。ref1悄悄地告诉了我们最终结果,但是它却没有解释清楚为什么要这么用。ref2提出了一个局部shuffle的办法,也是一个不错的点子。同时也需要注意的是,这个方法中把__getitem__当作了__iter__来用,因为每次总是取队列的最顶端,没有使用到传入的index参数。 ref3最主要的问题也是没有用到index。所以ref2和ref3无论DataLoader的shuffle参数是True还是False,都无关紧要了,因为根本不会用到这个属性。

简化策略

其实,上边使用csv的形式,会提供更大的可操作性,可以自己定制更加灵活的形式解决自己的问题。

但是!PyTorch早就为我们准备好了一个方便便捷的方法啦!
好多盆友使用的torchvision.datasets.ImageFolder,在这里同样附上源码,这个类就是继承的同文件下的DatasetFolder,只是封装好了一组后缀,也没有提供新的方法,所以这里以DatasetFloder的源码展开看一下。

文件夹格式留意看:

1
2
3
4
5
6
7
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext

root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext

首先看参数:__init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None)
root是根路径;loader是一个callable的方法,传入一个路径,读出一个东西;ext是指定后缀;后边两个transform;最后一个参数也是一个callsble的方法,判断某个文件是否是符合我们需要的方法。下边看一下这个类的运行流程:

  1. 读入root下的文件夹的名字,将名字对应转换成数字序列(0、1、2、3)这种形式。也就是说,无论label怎么存,都可以!
    也就是这段代码class_to_idx = {classes[i]: i for i in range(len(classes))}
  2. 下边进入make_dataset这个方法,这个方法会遍历root下所有的文件,如果指定了is_valid_file则使用这个函数判断文件是否合法,否则依照传入的ext后缀判断这个文件是否是我们需要的。
    经过一连串的判断,最后返回N个由地址和label组成的二元tuple,N是文件个数,label是数字形式的。
  3. 在getitem时,会调用loader进行读取数据,读取完后会对数据和label做transform。然后返回。
    当然在这里我们可以定义自己的loader。如果我们自定义的loader读上来就是pkl,那么就不需要再定义transform了,非常简单。
  4. 最后return sample, target,那么在for的每次迭代时就可以拿到这两个元素。

三代目步骤:
这个类把我们在再改进中的第2步骤都实现了有没有!而且还能配合上DataLoader的shuffle!现在可能只留下第1步需要我们做(如果不转pkl,那么什么也不用做,直接使用默认的loader即可)。
或者我们提前把测试集和训练集划分开,然后定义两个DatasetFolder就行啦!

演示

1
2
train_dataset = Datasets.ImageFolder(img_path, transform=transforms.ToTensor())
train_dataset = Datasets.DatasetFolder(img_path, extensions='pkl', loader=pklloder)

如果直接使用默认的图片后缀,如jpg、png等,可以直接使用ImageFolder即可,使用transform将读入的图片进行一些操作。因为我的图片都已经裁剪为224*224了,所以只需要totensor就行。如果像我一样使用了pkl,那么就需要自己写一个loader函数。

如果不使用shuffle,那么读入的顺序就像这样:

1
2
3
4
5
6
7
8
9
./crops/1/0044_035.pkl
./crops/1/0044_036.pkl
./crops/1/0044_037.pkl
./crops/1/0044_038.pkl
./crops/1/0044_039.pkl
./crops/1/0044_040.pkl
./crops/1/0044_041.pkl
./crops/1/0044_042.pkl
./crops/1/0044_043.pkl

如果使用了shuffle,那么就变成了这样:
1
2
3
4
5
6
7
8
9
./crops/4/0230_509.pkl
./crops/1/1483_064.pkl
./crops/3/1553_005.pkl
./crops/2/1742_005.pkl
./crops/3/1016_002.pkl
./crops/4/0230_564.pkl
./crops/4/2003_036.pkl
./crops/2/2607_008.pkl
./crops/3/0450_004.pkl

使用scikit-learn进行分层的划分数据,即按照label的比例划分:

1
2
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.1, stratify=y, random_state=0)

一个224*224*3大小的图片,png要80~90kB,jpg格式的要10~30kB,而pkl的要588KB!

reference

[1]PyTorch 中自定义数据集的读取方法小结
[2]pytorch加载大数据
[3]pytorch load huge dataset
[4]Pytorch 1.1.0 Docs

在下才疏学浅,如有描述有误的地方,还望不吝赐教。