Customizing ItemList
This tutorial explains how to create your own custom ItemBase or ItemList subclasses. It covers the following sections of the original English tutorial:
- Customizing datasets in fastai
- Links with the data block API
- Creating a custom ItemBase subclass
- Creating a custom ItemList subclass
Customizing datasets in fastai
In this tutorial, we'll see how to create custom subclasses of ItemBase or ItemList while retaining everything the fastai library has to offer. To allow basic functions to work consistently across various applications, the fastai library delegates several tasks to one of those specific objects, and we'll see here which methods you have to implement to be able to have everything work properly. But first let's take a step back to see where you'll use your end result.
Links with the data block API
The data block API works by allowing you to pick a class that is responsible for getting your items and another class that is charged with getting your targets. Combined together, they create a pytorch Dataset that is then wrapped inside a DataLoader. The training set, validation set and maybe test set are then all put in a DataBunch.
The data block API allows you to mix and match what class your inputs have, what class your targets have, how to do the split between train and validation set, then how to create the DataBunch, but if you have a very specific kind of input/target, the fastai classes might not be sufficient for you. This tutorial explains what is needed to create a new class of items and what methods are important to implement or override.
It goes in two phases: first we focus on what you need to create a custom ItemBase class (which is the type of your inputs/targets), then on how to create your custom ItemList (which is basically a set of ItemBase), while highlighting which methods are called by the library.
Creating a custom ItemBase subclass
The fastai library contains three basic types of ItemBase that you might want to subclass:
- Image for vision applications
- Text for text applications
- TabularLine for tabular applications
Whether you decide to create your own item class or to subclass one of the above, here is what you need to implement:
Basic attributes
These are the most important attributes your custom ItemBase needs, as they're used everywhere in the fastai library:
- ItemBase.data is the thing that is passed to pytorch when you want to create a DataLoader. This is what needs to be fed to your model. Note that it might be different from the representation of your item, since you might want something that is more understandable.
- The __str__ representation: if applicable, this is what will be displayed when the fastai library has to show your item.
If we take the example of a MultiCategory object o for instance:
- o.data is a tensor where the tags are one-hot encoded
- str(o) returns the tags separated by ;
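The contract above can be sketched in plain Python. The class and attribute layout below is illustrative, not fastai's actual implementation:

```python
# Hypothetical sketch of the MultiCategory contract: `data` is a one-hot
# vector over the known classes, and str() joins the active tags with ';'.
class MultiCategorySketch:
    def __init__(self, tags, classes):
        self.tags = tags
        # 1.0 where the class is one of this item's tags, 0.0 otherwise
        self.data = [1.0 if c in tags else 0.0 for c in classes]

    def __str__(self):
        return ';'.join(self.tags)

o = MultiCategorySketch(['dog', 'outdoor'], ['cat', 'dog', 'indoor', 'outdoor'])
print(o.data)   # [0.0, 1.0, 0.0, 1.0]
print(str(o))   # dog;outdoor
```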
If you want to code the way data augmentation should be applied to your custom Item, you should write an apply_tfms method. This is what will be called if you apply a transform block in the data block API.
Example: ImageTuple
For cycleGANs, we need to create a custom type of items since we feed the model tuples of images. Let's look at how to code this. The basis is to code the data attribute, which is what will be given to the model. Note that we still keep track of the initial object (usually in an obj attribute) to be able to show nice representations later on. Here the object is the tuple of images and the data their underlying tensors, normalized between -1 and 1.
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 10px 0px 25px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 241, 224); border: 0px; border-radius: 4px; white-space: pre-wrap;">class ImageTuple(ItemBase):
    def __init__(self, img1, img2):
        self.img1,self.img2 = img1,img2
        self.obj,self.data = (img1,img2),[-1+2*img1.data,-1+2*img2.data]
</pre>
Then we want to apply data augmentation to our tuple of images. That's done by writing an apply_tfms method, as we saw before. Here we just pass that call along to the two underlying images, then update the data.
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 10px 0px 25px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 241, 224); border: 0px; border-radius: 4px; white-space: pre-wrap;"> def apply_tfms(self, tfms, **kwargs):
        self.img1 = self.img1.apply_tfms(tfms, **kwargs)
        self.img2 = self.img2.apply_tfms(tfms, **kwargs)
        self.data = [-1+2*self.img1.data,-1+2*self.img2.data]
        return self
</pre>
We define a last method to stack the two images next to each other, which we will use later for a customized show_batch / show_results behavior.
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 10px 0px 25px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 241, 224); border: 0px; border-radius: 4px; white-space: pre-wrap;"> def to_one(self): return Image(0.5+torch.cat(self.data,2)/2)
</pre>
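The normalization stored in data and undone by to_one can be checked on plain floats; the helper names below are just for illustration:

```python
# Images stored in [0, 1] are normalized to [-1, 1] for the model
# (data = -1 + 2*x); to_one maps back with 0.5 + t/2.
def normalize(x):
    return -1 + 2 * x       # [0, 1] -> [-1, 1]

def denormalize(t):
    return 0.5 + t / 2      # [-1, 1] -> [0, 1]

# the round trip returns the original value
for v in (0.0, 0.25, 0.5, 1.0):
    assert abs(denormalize(normalize(v)) - v) < 1e-9
```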
This is all you need to create your custom ItemBase. You won't be able to use it until you have put it inside your custom ItemList though, so you should continue reading the next section.
Creating a custom ItemList subclass
This is the main class that allows you to group your inputs or your targets in the data block API. You can then use any of the splitting or labelling methods before creating a DataBunch. To make sure everything is properly working, here is what you need to know.
Class variables
Whether you're directly subclassing ItemList or one of the particular fastai ones, make sure to know the content of the following three variables as you may need to adjust them:
- _bunch contains the name of the class that will be used to create a DataBunch
- _processor contains a class (or a list of classes) of PreProcessor that will then be used as the default to create a processor for this ItemList
- _label_cls contains the class that will be used to create the labels by default
_label_cls is the first to be used in the data block API, in the labelling function. If this variable is set to None, the label class will be set to CategoryList, MultiCategoryList or FloatList depending on the type of the first item. The default can be overridden by passing a label_cls in the kwargs of the labelling function.
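A rough sketch of that fallback order in plain Python; apart from the _label_cls/label_cls names, everything here is illustrative, not fastai's actual code:

```python
# Pick the label class: an explicit label_cls kwarg wins, then the
# ItemList's _label_cls class variable, then a guess from the first label.
def pick_label_cls(list_label_cls, first_label, label_cls=None):
    if label_cls is not None:            # kwarg of the labelling function
        return label_cls
    if list_label_cls is not None:       # _label_cls on the ItemList
        return list_label_cls
    if isinstance(first_label, float):   # numeric target -> regression
        return 'FloatList'
    if isinstance(first_label, (list, tuple)):
        return 'MultiCategoryList'       # several tags per item
    return 'CategoryList'                # one class per item

assert pick_label_cls(None, 3.5) == 'FloatList'
assert pick_label_cls(None, ['a', 'b']) == 'MultiCategoryList'
assert pick_label_cls(None, 'a', label_cls='MyList') == 'MyList'
```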
_processor is the second to be used. The processors are called at the end of the labelling to apply some kind of function on your items. The default processor of the inputs can be overridden by passing a processor in the kwargs when creating the ItemList; the default processor of the targets can be overridden by passing a processor in the kwargs of the labelling function.
Processors are useful for pre-processing some data, but you also need to put in their state any variable you want to save for the call of data.export() before creating a Learner object for inference: the state of the ItemList isn't saved there, only their processors. For instance SegmentationProcessor's only reason to exist is to save the dataset classes, and during the process call, it doesn't do anything apart from setting the classes and c attributes to its dataset.
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 25px 0px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 245, 245); border: 1px solid rgb(204, 204, 204); border-radius: 4px; white-space: pre-wrap;">class SegmentationProcessor(PreProcessor):
    def __init__(self, ds:ItemList): self.classes = ds.classes
    def process(self, ds:ItemList): ds.classes,ds.c = self.classes,len(self.classes)
</pre>
_bunch is the last class variable used in the data block. When you type the final databunch(), the data block API calls the _bunch.create method with the _bunch of the inputs.
Keeping init arguments
If you pass additional arguments in your __init__ call that you save in the state of your ItemList, we have to make sure they are also passed along in the new method, as this one is used to create your training and validation sets when splitting. To do that, you just have to add their names to the copy_new argument of your custom ItemList, preferably during the __init__. Here we will need two collections of filenames (for the two types of images), so we make sure the second one is copied like this:
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 25px 0px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 245, 245); border: 1px solid rgb(204, 204, 204); border-radius: 4px; white-space: pre-wrap;">def __init__(self, items, itemsB=None, **kwargs):
    super().__init__(items, **kwargs)
self.itemsB = itemsB
self.copy_new.append('itemsB')
</pre>
Be sure to keep the kwargs as is, as they contain all the additional stuff you can pass to an ItemList.
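What copy_new buys you can be seen in a toy reimplementation (not fastai's actual code): new() forwards every attribute whose name is registered in copy_new, so the extra state survives the train/validation split.

```python
class ToyItemList:
    def __init__(self, items, itemsB=None):
        self.items, self.itemsB = items, itemsB
        self.copy_new = ['itemsB']          # attributes to forward on new()

    def new(self, items):
        # rebuild the list with the registered attributes carried over
        kwargs = {name: getattr(self, name) for name in self.copy_new}
        return self.__class__(items, **kwargs)

full = ToyItemList(['a.jpg', 'b.jpg', 'c.jpg'], itemsB=['x.jpg', 'y.jpg'])
train = full.new(full.items[:2])            # e.g. the training split
assert train.itemsB == ['x.jpg', 'y.jpg']   # itemsB survived the split
```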
Important methods
- get
The most important method you have to implement is get: this one will enable your custom ItemList to generate an ItemBase from the thing stored in its items array. For instance an ImageList has the following get method:
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 25px 0px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 245, 245); border: 1px solid rgb(204, 204, 204); border-radius: 4px; white-space: pre-wrap;">def get(self, i):
fn = super().get(i)
res = self.open(fn)
self.sizes[i] = res.size
return res
</pre>
The first line basically looks at self.items[i] (which is a filename). The second line opens it, since the open method is just:
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 25px 0px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 245, 245); border: 1px solid rgb(204, 204, 204); border-radius: 4px; white-space: pre-wrap;">def open(self, fn): return open_image(fn)
</pre>
The third line is there for ImagePoints or ImageBBox targets that require the size of the input Image to be created. Note that if you are building a custom target class and you need the size of an image, you should call self.x.size[i].
**Note:** If you just want to customize the way an Image is opened, subclass Image and just change the open method.
- reconstruct
This is the method that is called in data.show_batch(), learn.predict() or learn.show_results() to transform a pytorch tensor back into an ItemBase. In a way, it does the opposite of calling ItemBase.data. It should take a tensor t and return the same kind of thing as the get method.
In some situations (ImagePoints, ImageBBox for instance) you need to have a look at the corresponding input to rebuild your item. In this case, you should have a second argument called x (don't change that name). For instance, here is the reconstruct method of PointsItemList:
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 25px 0px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 245, 245); border: 1px solid rgb(204, 204, 204); border-radius: 4px; white-space: pre-wrap;">def reconstruct(self, t, x): return ImagePoints(FlowField(x.size, t), scale=False)
</pre>
- analyze_pred
This is the method that is called in learn.predict() or learn.show_results() to transform predictions into an output tensor suitable for reconstruct. For instance we may need to take the maximum argument (for Category) or the predictions greater than a certain threshold (for MultiCategory). It should take a tensor, along with optional kwargs, and return a tensor.
For instance, here is the analyze_pred method of MultiCategoryList:
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 25px 0px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 245, 245); border: 1px solid rgb(204, 204, 204); border-radius: 4px; white-space: pre-wrap;">def analyze_pred(self, pred, thresh:float=0.5): return (pred >= thresh).float()
</pre>
thresh can then be passed as a kwarg during the calls to learn.predict() or learn.show_results().
Advanced show methods
If you want to use methods such as data.show_batch() or learn.show_results() with a brand new kind of ItemBase, you will need to implement two other methods. In both cases, the generic function will grab the tensors of inputs, targets and predictions (if applicable) and reconstruct the corresponding ItemBase (as seen before), but it will delegate to the ItemList the way to display the results.
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 25px 0px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 245, 245); border: 1px solid rgb(204, 204, 204); border-radius: 4px; white-space: pre-wrap;">def show_xys(self, xs, ys, **kwargs)->None:
def show_xyzs(self, xs, ys, zs, **kwargs)->None:
</pre>
In both cases xs and ys represent the inputs and the targets; in the second case zs represents the predictions. They are lists of the same length that depend on the rows argument you passed. The kwargs are passed from data.show_batch() / learn.show_results(). As an example, here is the source code of those methods in ImageList:
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 25px 0px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 245, 245); border: 1px solid rgb(204, 204, 204); border-radius: 4px; white-space: pre-wrap;">def show_xys(self, xs, ys, figsize:Tuple[int,int]=(9,10), **kwargs):
    "Show the `xs` and `ys` on a figure of `figsize`. `kwargs` are passed to the show method."
    rows = int(math.sqrt(len(xs)))
    fig, axs = plt.subplots(rows,rows,figsize=figsize)
    for i, ax in enumerate(axs.flatten() if rows > 1 else [axs]):
        xs[i].show(ax=ax, y=ys[i], **kwargs)
    plt.tight_layout()
def show_xyzs(self, xs, ys, zs, figsize:Tuple[int,int]=None, **kwargs):
    """Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`.
    `kwargs` are passed to the show method."""
    figsize = ifnone(figsize, (6,3*len(xs)))
    fig,axs = plt.subplots(len(xs), 2, figsize=figsize)
    fig.suptitle('Ground truth / Predictions', weight='bold', size=14)
    for i,(x,y,z) in enumerate(zip(xs,ys,zs)):
        x.show(ax=axs[i,0], y=y, **kwargs)
        x.show(ax=axs[i,1], y=z, **kwargs)
</pre>
Linked to this method is the class variable _show_square of an ItemList. It defaults to False, but if it's True, the show_batch method will send rows * rows xs and ys to show_xys (so that it shows a square of inputs/targets), like here for images.
Example: ImageTupleList
Continuing our custom item example, we create a custom ItemList class that will wrap those ImageTuples properly. The first thing is to write a custom __init__ method (since we need a list of filenames here), which means we also have to change the new method.
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 10px 0px 25px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 241, 224); border: 0px; border-radius: 4px; white-space: pre-wrap;">class ImageTupleList(ImageList):
    def __init__(self, items, itemsB=None, **kwargs):
        super().__init__(items, **kwargs)
self.itemsB = itemsB
self.copy_new.append('itemsB')
</pre>
We then specify how to get one item. Here we open the image in the first list of items, and pick one randomly in the second list.
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 10px 0px 25px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 241, 224); border: 0px; border-radius: 4px; white-space: pre-wrap;"> def get(self, i):
img1 = super().get(i)
fn = self.itemsB[random.randint(0, len(self.itemsB)-1)]
return ImageTuple(img1, open_image(fn))
</pre>
We also add a custom factory method to directly create an ImageTupleList
from two folders.
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 10px 0px 25px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 241, 224); border: 0px; border-radius: 4px; white-space: pre-wrap;"> @classmethod
def from_folders(cls, path, folderA, folderB, **kwargs):
itemsB = ImageList.from_folder(path/folderB).items
res = super().from_folder(path/folderA, itemsB=itemsB, **kwargs)
res.path = path
return res
</pre>
Finally, we have to specify how to reconstruct the ImageTuple from tensors if we want show_batch to work. We recreate the images and denormalize.
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 10px 0px 25px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 241, 224); border: 0px; border-radius: 4px; white-space: pre-wrap;"> def reconstruct(self, t:Tensor):
return ImageTuple(Image(t[0]/2+0.5),Image(t[1]/2+0.5))
</pre>
There is no need to write an analyze_pred method since the default behavior (returning the output tensor) is what we need here. However show_results won't work properly unless the target (which we don't really care about here) has the right reconstruct method: the fastai library uses the reconstruct method of the target on the outputs. That's why we create another custom ItemList with just that reconstruct method. The first line is to reconstruct our dummy targets, and the second one is the same as in ImageTupleList.
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 10px 0px 25px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 241, 224); border: 0px; border-radius: 4px; white-space: pre-wrap;">class TargetTupleList(ItemList):
def reconstruct(self, t:Tensor):
if len(t.size()) == 0: return t
return ImageTuple(Image(t[0]/2+0.5),Image(t[1]/2+0.5))
</pre>
To make sure our ImageTupleList uses that for labelling, we pass it in _label_cls, and this is what the result looks like.
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 10px 0px 25px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 241, 224); border: 0px; border-radius: 4px; white-space: pre-wrap;">class ImageTupleList(ImageList):
_label_cls=TargetTupleList
    def __init__(self, items, itemsB=None, **kwargs):
        super().__init__(items, **kwargs)
self.itemsB = itemsB
self.copy_new.append('itemsB')
def get(self, i):
img1 = super().get(i)
fn = self.itemsB[random.randint(0, len(self.itemsB)-1)]
return ImageTuple(img1, open_image(fn))
def reconstruct(self, t:Tensor):
return ImageTuple(Image(t[0]/2+0.5),Image(t[1]/2+0.5))
@classmethod
def from_folders(cls, path, folderA, folderB, **kwargs):
itemsB = ImageList.from_folder(path/folderB).items
res = super().from_folder(path/folderA, itemsB=itemsB, **kwargs)
res.path = path
return res
</pre>
Lastly, we want to customize the behavior of show_batch and show_results. Remember the to_one method just puts the two images next to each other.
<pre style="box-sizing: border-box; overflow: auto; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 13px; display: block; padding: 9.5px; margin: 10px 0px 25px; line-height: 1.42857; color: rgb(51, 51, 51); word-break: break-all; word-wrap: break-word; background-color: rgb(245, 241, 224); border: 0px; border-radius: 4px; white-space: pre-wrap;"> def show_xys(self, xs, ys, figsize:Tuple[int,int]=(12,6), **kwargs):
        "Show the `xs` and `ys` on a figure of `figsize`. `kwargs` are passed to the show method."
rows = int(math.sqrt(len(xs)))
fig, axs = plt.subplots(rows,rows,figsize=figsize)
for i, ax in enumerate(axs.flatten() if rows > 1 else [axs]):
xs[i].to_one().show(ax=ax, **kwargs)
plt.tight_layout()
def show_xyzs(self, xs, ys, zs, figsize:Tuple[int,int]=None, **kwargs):
"""Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`.
        `kwargs` are passed to the show method."""
figsize = ifnone(figsize, (12,3*len(xs)))
fig,axs = plt.subplots(len(xs), 2, figsize=figsize)
fig.suptitle('Ground truth / Predictions', weight='bold', size=14)
for i,(x,z) in enumerate(zip(xs,zs)):
x.to_one().show(ax=axs[i,0], **kwargs)
z.to_one().show(ax=axs[i,1], **kwargs)
</pre>