美文网首页大数据,机器学习,人工智能
元学习(Meta-learning)——让机器学习如何学习

元学习(Meta-learning)——让机器学习如何学习

作者: 偶尔写一写 | 来源:发表于2022-01-10 12:09 被阅读0次

1 元学习概述

元学习的意思即“学会如何学习” 。 在机器学习中,工作量最大也是最无聊的事情就是调参。我们针对每一个任务从头开始进行这种无聊的调参,然后耗费大量的时间去训练并测试效果。因此,一个直观的想法是:我们是否能让机器自己学会调参,在遇到相似任务时能够触类旁通、举一反三,用不着我们从头开始调参,也用不着大量标签数据重新进行训练。
通常的机器学习是针对一个特定的任务找到一个能够实现这个任务的function,例如猫和狗的分类任务。而元学习的目标就是要找到一个Function能够让机器自动学习原来人为确定的一些超参(Hyper-parameter),如初始化参数\theta_0、学习速率\eta、网络架构等,元学习的分类就是看学习的是什么超参。这个Function用F_{\phi}表示,F_{\phi}不是针对某一个特定任务的,而是针对一群类似的任务,例如这些任务可能包括猫和狗的分类、橘子和苹果的分类、自行车和摩托车的分类等等。这个F_{\phi}是要帮这一群类似任务找到一个好的超参,在下次再遇到相似任务的时候,初始化参数可以直接用上,用不着我们再调参了。

image.png

元学习是跨任务学习(multi-task learning),因此它需要收集多个类似任务的数据集。比如针对图片二分类任务,我们需要收集橙子和苹果训练数据和测试数据、自行车和汽车的训练数据和测试数据等等许多二分类任务的数据集。元学习的目标是:利用F_{\phi}找到最优的超参\phi,使各任务在超参\phi的基础上训练出最优参数后测试得到的损失值l^n的和最小。这句话讲起来比较难以理解,举个例子比较好明白:对于苹果和橙子的分类任务,在超参\phi的基础上利用训练数据集进行训练,得到最优参数\theta^{1*},然后再利用测试数据集对训练后的模型进行测试,测试得到的损失值使l^1;同理,可以得到自行车和汽车分类任务的测试损失值l^2,以及其他二分类任务的测试损失值l^n;元学习的目标就是要找到最优超参\phi,使所有任务的测试损失值之和最小。所以元学习的损失函数定义为,L(\phi)=\sum_{n=1}^N{l^n}这里每一个用于训练超参\phi的任务都称为训练任务,上面的N指所有训练任务的总数。如果在拿一个新的任务(该任务未在训练任务中出现过)来测试通过训练找到的超参\phi的效果,那么这个任务就称为测试任务。
我们可以看到在每一个训练任务中包含了训练数据和测试数据,当然在测试任务中也包含了训练数据和测试数据,这和普通机器学习是大不同的。这样听起来很容易让人迷糊,所以有的文献不叫训练数据和测试数据,而是把训练数据叫支持集(support set),把测试数据叫查询集(query set)。

image.png
元学习的目标是要找到超参\phi最小化损失函数L(\phi)如果能够计算梯度,那么用梯度下降法求解即可。但是有很多情况使无法求梯度的,例如对网络架构的优化,此时有些文献会采用强化学习或进化算法等方法进行求解。
image.png

2 MAML

2.1 MAML概述

在普通机器学习中,初始化参数往往是随机生成的,MAML聚焦于学习一个最好的初始化参数\phi。初始参数\phi不同,对于同一个任务n训练得到的最优参数\hat{\theta}^n不同,在任务n的测试数据集上损失值l^n(\hat{\theta}^n)不同。MAML的目标是找到最优的初始参数\phi,是所有任务的测试损失值最小,在遇到新任务时,只需基于少量标签对初始化参数\phi进行微调就可以获得很好的效果。这和前面提到的预训练有些相似,但也有些不同。

image.png

2.2 MAML的训练

MAML的训练使用梯度下降法:\phi \leftarrow \phi-\eta \nabla_{\phi} L(\phi)具体的数学推导不管它了,我们直接看上面的梯度下降是如何实施的(这里假设batch size是1):

  • ①假设刚开始的初始化参数的初始化参数是\phi_0
  • ②随机采样一个训练任务m
  • ③通过训练任务m的支持集(训练数据)求loss,然后更新一次参数得到任务m的最优参数\hat{\theta}^m,注意此时并没有更新\phi
  • ④通过训练任务m的查询集(测试数据)再求一次loss,计算梯度,然后用此梯度的方向更新\phi
  • ⑤回到②,重复②③④
    上述过程有一个很值得关注的地方是,对于任务m,更新一次参数就认为参数最优了,李宏毅认为作者之所以这么设置是因为MAML主要用于小样本学习,更新一次是怕发生过拟合的问题。
    之前说MAML和与训练模型很像,但是也有所不同。不同点是与训练模型用任务的训练数据求loss就直接更新\phi了,而不是在测试数据上二次求loss后才更新的。
image.png

结合上面的讲解,来看一看MAML原文的算法,如下图所示。首先随机在这个算法里采样一个batch的训练任务,注意这里的batch是任务而不是数据。对于这一个batch的所有任务:第5行对每一个训练任务T_i,通过支持集求loss,计算梯度;第6行根据第5行算出来的梯度更新一次参数得到\theta'_i,并且保存起来。假设一个batch有10个任务,那么这里就保存了10个模型参数\theta'_i。完成了一个batch所有任务的参数更新后,进行第8行:基于更新后的参数和所有任务的查询集计算出各自的loss,将这些loss求和,计算出梯度,利用该梯度更新初始参数。假设一个batch有10个任务,基于更新后的参数和这10个batch任务的查询集计算出10个loss,将这10个loss进行求和,并基于求和结果计算梯度,利用该梯度更新初始参数。

image.png
我截了一部分MAML的代码,通过分析代码就可以更好理解上述参数的更新过程了。以下是第一次更新参数的代码
for k in range(1, self.update_step):
       # 1.在支持集上计算loss
       logits = self.net(x_spt[i], fast_weights, bn_training=True)
       loss = F.cross_entropy(logits, y_spt[i])
       # 2. 利用上面的loss,计算梯度
       grad = torch.autograd.grad(loss, fast_weights)
       # 3. 更新参数:theta_pi = theta_pi – train_lr * grad
       fast_weights = list(map(lambda p: p[1] – self.update_lr * p[0], zip(grad, fast_weights)))
       # 4.基于更新后的参数,在查询集上计算loss
       logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
       loss_q = F.cross_entropy(logits_q, y_qry[i])
       # 5把所有loss加起来,并保存.
       losses_q[k + 1] += loss_q

以下是第二次更新参数的代码

# 将所有任务的查询集上的loss的和除以任务数目,求了个平均值
loss_q = losses_q[-1] / task_num
# 利用上面的loss算梯度,并更新初始化参数
self.meta_optim.zero_grad()
loss_q.backward()
self.meta_optim.step()

11.3 元学习在N-ways K-shot上的应用

N-way K-shot是典型的小样本学习问题。所谓N-way K-shot是指在每一个任务里面,有N个类别,每个类别有K个样本。Omniglot是一个典型例子它包含1632个不同的字符,每个字符只有20个样本。从上面1632字符中可以构建N-way K-shot任务。例如通过下面的方式构建一个小样本分类任务:抽出20个字符出来,里面每个字符只有1个样本,我们把这个数据集作为训练集(支持集),那就是20-ways 1-shot的问题。然后再在这20个字符中取1个样本出来,作为测试集(查询集),利用训练出来的模型来判断这个样本属于哪个字符。通过这种方式,可以构建出许多任务来,如果是20 ways的就可以构建出81个任务。这些任务又可以分为训练任务和测试任务,例如将81个任务中的60个任务作为训练任务,21个任务作为测试任务。

image.png
拥有了这些数据集后,就可以来测试MAML了。以下是MAML原文的测试结果。从测试结果来看,MAML处理N-way K-shot任务是非常棒的。
image.png

相关文章

  • 元学习(Meta-learning)——让机器学习如何学习

    1 元学习概述 元学习的意思即“学会如何学习” 。 在机器学习中,工作量最大也是最无聊的事情就是调参。我们针对每...

  • 元学习

    1 元学习概念 元学习 (Meta-Learning) 通常被理解为“学会学习 (Learning-to-Lear...

  • Adaptive Cross-Modal Few-shot Le

    论文 资料1 基于度量的元学习(metric-based meta-learning)如今已成为少样本学习研究过程...

  • Meta Learning——MAML

    Meta Learning就是元学习,所谓元学习,就是学习如何去学习。这个概念放在机器学习当中就是我们希望找到一个...

  • 机器学习简介

    一、什么是机器学习? 让机器,去学习 二、机器如何学习? 三、人工智能、机器学习、深度学习之间的关系 四、平台选择...

  • 初探元学习(Meta-Learning)

    九月中下旬 好像并没有那么顺心 很多东西积压着思绪无从下手 很久没有冒痘 也冒痘了 生理和心理双重奏 希望接下...

  • Higher Library

    Higher是FAIR开源的一个元学习框架,主要针对gradient-based meta-learning。在g...

  • 3.gitchat训练营-原理

    3.1.机器是如何学习的? 什么是机器学习?机器学习就是:让计算机程序(机器),不是通过人类直接指定的规则,而是通...

  • meta-learning初印象

    本文是对李宏毅老师课程的总结。 什么是meta-learning一般的机器学习方法是学习到处理一种任务的algor...

  • 一种迁移学习和元学习的集成模型

    导言 本文提出了一种将迁移学习和元学习结合在一起的训练方法。本文是论文A Meta-Learning Approa...

网友评论

    本文标题:元学习(Meta-learning)——让机器学习如何学习

    本文链接:https://www.haomeiwen.com/subject/juiqcrtx.html