suyc's blog

哎,什么时候才能把英语学好啊🙇‍~

Iterator和Iterable

这两天再看Nvidia的DALI的用法,和Pytorch的代码在进行比较的时候,突然一个问题困惑到了我,__iter____next__要怎么用?
这里推荐三个资料:
[1] 定制类 - 廖雪峰
[2] 生成器 - 廖雪峰
[3] 【python魔术方法】迭代器(iternext) - Liburro

推荐先把上边三个资料看一下。在这里我想分享三段代码,即我对Iterator和Iterable的理解。

代码段一

1
2
3
4
5
6
7
>>> def odd():
... print('step 1')
... yield 1
... print('step 2')
... yield(3)
... print('step 3')
... yield(5)

这个例子来自廖雪峰老师的博客。这个是一个函数,只要在函数中包含了yield,这个函数即被看作是generator,可以看到,这个函数既是Iterable又是Iterator。(PS:我印象里Python对类和函数的界限分的不是很清楚,都是面向对象的)

1
2
3
4
>>> isinstance(odd(), Iterator)
True
>>> isinstance(odd(), Iterable)
True

直接使用for是最简单的办法:

1
2
3
4
5
6
7
8
9
>>> for o in odd():
... print(o)
...
step 1
1
step 2
3
step 3
5

含有step的是odd函数中的输出,135则是for循环中的输出。分开来看一下:

1
2
3
4
5
6
>>> next(odd())
step 1
1
>>> next(odd())
step 1
1

当调用next的时候,每次都会调用__next__,因为传入了不同的对象,所以会生成两个generator。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
>>> o = odd()
>>> next(o)
step 1
1
>>> next(o)
step 2
3
>>> next(o)
step 3
5
>>> next(o)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
StopIteration

用上面向对象以后,这样就和for循环一样的,注意的是当结束时会抛出StopIteration,for循环遇到这个错误即会停止。

代码段二

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
>>> class Test(object):
... def __init__(self):
... pass
... def __iter__(self):
... for i in range(5):
... yield i
...
>>> x = Test()
>>> print(isinstance(x, Iterator))
False
>>> print(isinstance(x, Iterable))
True
>>> print(isinstance(iter(x), Iterator))
True
>>> print(isinstance(iter(x), Iterable))
True

上边资料三中也介绍到了这个部分,只有__iter__那么这个对象只是一个Iterable的对象。
但是当调用iter时即调用了对象的__iter__方法,因为这个方法内有yield,所以这个方法并不会返回,而是直接把这个方法作为了一个generator,也就变成了和代码段一一样的一个情况,使得新的对象既可Iterable又可Iterator,当然也可以使用for去遍历。资料三中提到,在调用for的时候,实际上会自动转换为for i in iter(x)

1
2
3
4
5
6
7
8
>>> for i in x:
... print(i)
...
0
1
2
3
4

代码段三

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
>>> class Test(object):
... def __init__(self):
... self.start = 1
... self.end = 5
... def __iter__(self):
... return self
... def __next__(self):
... if self.start < self.end:
... ret = self.start
... self.start += 1
... return ret
... else:
... raise StopIteration
...
>>>
>>> for x in Test():
... print(x)
...
1
2
3
4

这种写法也就是最常见的,上边资料三中对其作了详细的介绍,就不再展开了。原理就是不断调用Test()对象的__next__后遇到StopIteration即停止。

那么如果__next__中含有yield呢?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
>>> class Test(object):
... def __init__(self):
... pass
... def __iter__(self):
... return self
... def __next__(self):
... for i in range(5):
... yield i
...
>>> for i in Test():
... print(i)
...
<generator object Test.__next__ at 0x7f0325db1950>
<generator object Test.__next__ at 0x7f0325db1a50>
<generator object Test.__next__ at 0x7f0325db1950>
<generator object Test.__next__ at 0x7f0325db1a50>
<generator object Test.__next__ at 0x7f0325db1950>
<generator object Test.__next__ at 0x7f0325db1a50>
<generator object Test.__next__ at 0x7f0325db1950>
<generator object Test.__next__ at 0x7f0325db1a50>
<generator object Test.__next__ at 0x7f0325db1950>
<generator object Test.__next__ at 0x7f0325db1a50>
<generator object Test.__next__ at 0x7f0325db1950>
<generator object Test.__next__ at 0x7f0325db1a50>

这个for循环会无休止的产生新的generator,不会自动停止,因为每次调用__next__都会遇到yield,都会把当前这个__next__封装成一个新的generator。有了上边的例子其实不难理解,每个generator都应该用一个for循环去迭代,于是可以这么做:

1
2
3
4
5
6
7
8
9
10
11
12
>>> for i in Test():
... for ii in i:
... print(ii, end=' ')
... print()
...
0 1 2 3 4
0 1 2 3 4
0 1 2 3 4
0 1 2 3 4
0 1 2 3 4
0 1 2 3 4
0 1 2 3 4

就产生了无休止的迭代。

回归PyTorch

之前我写的一篇关于PyTorch读取数据的文章link,其中有BatchSampler、Dataloader等的用法和相关的代码,推荐看一看。BatchSampler类的实现方法就是上述代码段二
再附上一段删减版的,全部的可见Doc

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class BatchSampler(Sampler):

def __init__(self, sampler, batch_size, drop_last):
...

def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch

def __len__(self):
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size

其中self.sampler也是这种结构,没有__next__,当它StopIteration时,即结束循环,那么batch中可能装了一部分数据,也可能恰好是空的,于是就要看是否drop_last。当最后一个yield结束以后,这个对象也会StopIteration。

当调用Dataloader时,实际上时在调用_DataLoaderIter__next__,也就是代码段三,它在处理单线程的读取的时候,并没有主动抛出StopIteration,那它是怎么停的呢?
实际上就是这行代码:

1
2
3
self.sample_iter = iter(self.batch_sampler)
...
indices = next(self.sample_iter) # may raise StopIteration

可以看到,它先把BatchSampler对象用iter包装上,然后迭代,当这个对象执行完,即最后一个yield执行后,这里并没有接收这个raise,继续抛出,那么for x, y in Dataloader也就可以停止了。