使用持续学习在预训练代码语言模型中进行超出分布泛化

互联不一般哥 2024-07-21 16:10:41

引用

Martin Weyssow, Xin Zhou, Kisub Kim, David Lo, and Houari Sahraoui. 2023. On the Usage of Continual Learning for Out-of-Distribution Generalization in Pre-trained Language Models of Code. In Proceedings of the 31st ACM Joint European Software Engineering Conference and Symposium on the Foundations of Software Engineering (ESEC/FSE ’23), December 3–9, 2023, San Francisco, CA, USA. ACM, New York, NY, USA, 13 pages.

https://doi.org/10.1145/3611643.3616244

摘要

预训练语言模型(PLMs)已成为深度学习中代码领域的一种普遍技术,利用两阶段的预训练和微调过程,获取关于代码的通用知识,并在各种下游任务中进行专业化。然而,软件代码库动态变化的特性对PLMs的有效性和鲁棒性构成了挑战。特别是,真实世界的场景可能导致预训练数据与测试数据的分布显著不同,即分布偏移,从而降低PLMs在下游任务中的性能。本文强调了需要使代码的PLMs适应随时间变化的软件数据的必要性,这是先前工作中被忽视的关键问题。本研究的动机在于考虑PLMs在非稳态环境下的应用,即微调数据根据软件演化场景随时间变化。具体来说,我们设计了一个场景,模型需要从包含新的、未见过的API的程序流中学习。我们研究了两种广泛使用的PLM架构,即GPT2解码器和RoBERTa编码器,在两个下游任务上,即API调用和API使用预测。我们证明了先前工作中最常用的微调技术在处理API动态性方面不够健壮,导致之前获取的知识丧失,即灾难性遗忘。为解决这些问题,我们实施了五种持续学习方法,包括基于重放和基于正则化的方法。我们的研究结果表明,有效利用这些简单方法可以有效减轻PLMs在两个下游任务中的灾难性遗忘,并实现可比或更优的性能。

1 引言

过去关于代码表示学习的研究利用了一种普遍的两阶段过程,有效地训练和专门化预训练语言模型(PLMs)用于代码相关的下游任务。第一阶段即预训练,通过在大规模数据集上进行自监督学习来优化模型,获取关于代码的通用知识。这一预训练阶段使模型能够在第二阶段,即微调阶段,适应各种下游任务。先前的研究通常采用经典的迁移学习方法,即通过在任务特定的损失函数和数据上微调模型,将预训练知识“转移”到目标任务中。这种方法在自然语言处理(NLP)和深度学习领域已经取得成功。

在这一视角下,以往的研究主要集中在静态设置上,忽视了模型需要适应随时间变化的环境和数据的实际需求。大多数先前的研究建议在静态环境中使用迁移学习来微调模型,而不是解决现实世界场景中软件代码库、软件库和API动态变化的问题,导致底层软件数据分布随时间变化,即所谓的概念漂移。通过忽视软件代码库的实际演变,现有研究集中于使用静态数据集对代码的预训练模型进行微调和测试。实际上,软件演变可能导致训练数据和测试数据之间显著差异,即分布偏移,这种现象在这些静态数据集中通常不存在。当模型投入生产并需要处理真实世界数据时,这种差异尤为显著。我们认为,创建能够反映真实软件演变场景和分布变化的数据集,以正确评估代码模型的超出分布(OOD)泛化能力,这具有重要意义。

图1:对预训练的代码语言模型进行持续微调。在预训练后,模型需要随时间适应新的超出分布(OOD)的程序数据

现有关于OOD泛化的研究设计了基于源代码数据中各种分布变化的数据集。然而,它们未解决将预训练模型持续适应OOD数据流的问题。我们研究的主要目标是探索模型在更好适应软件演变场景中的方法。在这一背景下,我们提出以下问题:如何有效地持续微调预训练的代码模型,以适应新数据,同时考虑过去的数据?(见图1)。近年来,持续学习(CL)已经出现以解决这一问题,其与包括计算机视觉和自然语言处理在内的广泛研究领域具有相关性。尽管迁移学习方法并非专为持续学习场景而设计,但它们仍可在数据流上微调模型。然而,这些方法缺乏鲁棒性,可能导致意外现象,如灾难性遗忘。存在其他策略,例如使用新数据从头开始重新训练模型,但由于预训练阶段的巨大计算强度,这种方法也不实际。受现有模型问题的启发,我们尝试研究更强大和可扩展的微调技术。我们假设,在这一背景下,持续学习技术可能比经典迁移学习带来显著的好处。

本文深入探讨了PLMs在持续微调场景中的行为,如图1所示。我们的目标有二:(1)评估PLMs在代码的超出分布泛化能力;(2)研究在存在OOD数据流的情况下,有效的持续微调策略。具体来说,我们在一个反映典型软件代码库可能在实践中演变的场景中解决这些挑战。为此,我们创建了五个OOD领域数据集,每个数据集在预训练阶段通过引入新的、未见过的API来模拟持续微调的数据流。每个数据集相对于预训练数据都包含显著的分布偏移。因此,我们的设置确立了一个OOD泛化问题。我们考虑了两种广泛使用的模型架构:类似GPT2的解码器和类似RoBERTa的编码器,它们都在代码预训练上进行了微调。为了消除预训练数据和微调数据之间的数据泄露,我们决定从头开始预训练我们的模型。我们不研究像CodeBERT或CodeT5这样的流行PLMs,因为它们可能存在潜在的数据泄露问题,即在预训练中已经接触到OOD数据,我们无法精确控制。我们在两个下游任务上评估模型的性能:API调用预测和API使用预测。在第一个任务中,模型尝试预测API调用,在调用站点之前出现的代码标记,结果是一个代码标记。另一方面,第二个任务涉及生成整个API使用,生成与前一个任务相同输入格式的代码标记序列。这两个任务共同评估了模型在不同代码生成场景下的性能。

我们首先研究OOD数据对GPT2-like解码器在zero-shot设置下的影响,即在新的OOD数据上不进行微调模型。我们发现,与分布内数据相比,模型在OOD数据上普遍无法泛化,表现出显著的性能差距,六个评估指标中高达75%的BLEU得分下降。这一发现强烈表明,单独的预训练并不能解决PLMs在代码中的OOD泛化问题。接着,我们在持续微调场景中评估模型的性能,使用经典的迁移学习方法观察到明显的灾难性遗忘现象。为了解决这一问题,我们实施了一种直观但计算效率低下的累积微调方法,利用无限大小的重放缓冲区。结果显示,这种方法显著减轻了遗忘问题。最后,我们比较了经典迁移学习与基于重放和基于正则化的持续学习方法的性能。重放方法被认为是持续学习中难以超越的策略,它包括维护一个包含先前看到数据样本的小型重放缓冲区。在微调过程中,我们将重放缓冲区与当前OOD训练集结合使用,微调PLM。我们探索了基于正则化的方法,包括EWC、SI和RWalk,它们在微调时向损失函数添加正则化项,防止PLM的重要参数发生大幅变化。我们选择这些方法,因为它们在计算上效率高,广为人知,并且被认为是持续学习文献中的强基线。我们发现,这些持续学习方法显著减少了遗忘问题,同时在两个任务中实现了类似或更优的效果。

据我们所知,这项工作是首次尝试研究PLMs在代码的OOD泛化中的持续微调。我们相信,持续学习在这一研究领域的影响具有潜力,特别是由于软件数据随时间的内在演变。我们的贡献总结如下:

(1)我们展示了PLMs在OOD数据上泛化能力的不足,并强调了在这一领域进一步研究的必要性。

(2)我们在持续学习环境中研究了两种预训练模型架构在代码中的行为,显示经典迁移学习的缺乏鲁棒性和容易出现灾难性遗忘。

(3)我们比较了五种持续学习方法,包括基于重放和基于正则化的方法,在我们的持续微调场景中。我们展示了持续学习在经典迁移学习方面的优越性。

(4)我们提供了一个包括Java代码片段及其API使用序列的大规模数据集,包括预训练数据,并提供了提取OOD数据的过程。

2 实验设置

2.1 数据集构建

从头开始对语言模型进行预训练需要大量的数据,以使模型的损失函数收敛。基于这一点,我们利用Google BigQuery从GitHub上抓取了大量程序,构建了我们的大型数据集。具体来说,我们专注于Java程序,首先收集了所有存储在GitHub仓库中的Java文件。接着,我们使用Groum工具提取了这些Java文件中定义的所有方法及其API使用序列。我们提取这些API使用序列是为了方便数据划分,并获得每个API在方法中的位置,以便实现下游任务。每个样本由一个方法中的所有标记组成。为了避免实验中的重复偏差,我们通过比较每个方法的哈希值对数据集进行了去重。最终的数据集包含超过6800万条Java方法。在我们的实验中,我们对这6800万条方法进行了洗牌,随机选取了1000万条方法作为初始数据集。图2展示了我们如何进一步划分数据以进行实验。由于我们选择了从头开始进行PLMs(预训练语言模型)预训练,因此需要将数据划分为用于模型预训练的分布内(ID)数据和用于持续微调的超出分布(OOD)数据。同时,我们还需要适当地提取OOD数据,以符合在微调过程中向PLM引入新API的场景。

图2:将ID数据用于模型预训练以及OOD数据用于持续微调的过程

超出分布数据集——。我们创建了五个OOD数据集,, ..., 。每个OOD数据集代表一个唯一的领域,涵盖了API的高级功能。例如,我们有一个名为Security的领域,包含与编写安全相关代码的API相关的内容,还有一个名为Guava的领域,仅包括来自Guava库的API。为了创建每个OOD数据集,我们从与其领域相关的包/库中随机选择了10个接口。最后,我们将每个领域数据集与所选接口中的所有API关联起来,排除类构造方法。表1总结了数据集,包含总共147,245个样本。

表1:超出分布数据集的细节

为了形成每个OOD数据集,我们从1000万条Java方法池中选择了至少操作一个相关API的样本。在我们的实验中,我们对与OOD数据集, ..., 相关的训练集进行了顺序的持续微调。因此,为了防止数据泄露,我们排除了操作多个领域API的样本。这一过程消除了对我们OOD场景有效性的重大威胁,确保在微调过程中API的引入是符合预期的。为了获得具有代表性的测试集,我们从每个OOD数据集中操作每个API的样本中随机选择了10%来形成相应的领域测试集。

分布内数据集——。我们通过从初始数据中移除中的样本来获得。然后,我们对进行了洗牌,并随机选择了50,000个样本作为测试集()。DID_test包含了剩余的样本用于预训练,我们随机选择了100,000个样本用于模型验证(DID_PT_valid)。特别地,这些样本使我们能够在独立的验证集上监控模型损失的演变,以避免对预训练数据的过拟合。总的来说,预训练集DID_PT_train包含了超过900万条样本来进行模型的预训练。

2.1 模型和任务的搭建

在本工作中,我们考虑了两种广泛使用的代码深度学习架构:类似RoBERTa的编码器和类似GPT2的解码器。由于大规模语言模型(LLMs)的预训练需要大量计算资源,我们故意在研究中排除了它们的使用。为了全面处理我们的超出分布(OOD)场景,我们必须先从头开始预训练模型,然后持续微调新的、未见过的API代码。因此,我们选择评估两种较小的模型架构,即RoBERTa和GPT-2,它们要么作为像CodeBERT这样的PLMs的基础模型,要么作为生成模型。

解码器–Mₑₓₑₚ。解码器模型基于GPT-2架构,具有相同的超参数,并使用因果语言建模目标进行预训练,即从左到右的下一个标记预测。由于我们在有限的资源下进行实验,我们实现了一个具有1.1亿可训练参数的小型GPT-2版本,并将模型预训练了10万步。我们使用早停策略基于验证集DID_PT_valid上的损失选择最佳模型检查点。

编码器–Mₑₙₚ。编码器模型基于RoBERTa架构,具有相同的超参数,并使用遮蔽语言建模目标进行预训练。我们实现了一个基础版本的RoBERTa模型,拥有1.25亿可训练参数,并且类似于解码器模型进行了预训练,使用早停策略选择最佳检查点。需要注意的是,与解码器Mₑₓₑₚ相反,编码器的架构不适用于生成任务。因此,我们在其上添加了一个随机初始化的语言建模头部,用于使用OOD数据集进行微调。因此,我们预期Mₑₙₚ比Mₑₓₑₚ不稳定,并且更容易发生灾难性遗忘,因为语言建模头部没有进行预训练。这种比较为我们提供了对两种不同架构稳健性的宝贵见解。

下游任务。我们采用两个下游任务来评估我们的代码PLMs学习和适应新软件数据的能力,这些数据随时间引入了新的、未见过的API。图3说明了这两个任务。对于API调用预测,模型将方法中调用API之前的所有标记作为输入,并生成前k个候选项。对于API使用预测,模型使用与API调用预测任务相同的标记作为输入,但试图生成整个API使用(接口名称、方法名称、参数和语法标记),这是一个更具挑战性的任务。评估PLMs在这两个下游任务上的表现,旨在选择API先前知识在任务中似乎至关重要的任务。因此,选择这两个任务与我们持续的OOD场景密切相关,并允许我们直接衡量OOD API对PLMs效果的影响。

图3:下游任务概述。在API调用预测任务中,模型输出一个顶部k个候选项的列表,以预测API调用标记(即最小)。在API使用预测任务中,模型试图预测构成API使用的所有标记(接口名称、方法名称、参数和语法标记)。模型仅利用遗留上下文标记生成预测。

评估指标。我们使用先前工作中使用的指标来衡量模型在两个下游任务上的性能。对于API调用预测,我们报告精确匹配@k(EM@k),它在考虑k个候选项列表时给出了正确预测的百分比。对于API使用预测,我们报告BLEU分数、精确匹配(EM)和CodeBLEU 。

为了衡量模型在持续学习环境中的表现,我们使用两个元指标,这些指标是从先前工作中改编的:平均(A)和遗忘(F)指标。我们定义在测试数据集上的平均AM如下:

在这里,j指的是包括第i个后续增量学习步骤。表示评估指标,例如,在测试集上第j个时间步骤上的EM@k。T表示微调步骤的最大数量,在本文中为五步。平均指标仅提供模型准确性的信息,但不揭示其减轻灾难性遗忘能力的见解。我们定义在时间步k上的指标M的忘记如下:

这是计算指标的第一次时间,例如在i时间步在数据集上对模型进行了微调后,k时间步上的指标。提供了有关模型稳定性的信息,即其不过去忘记的能力。因此,越低越好。

实施细节。为了预训练Mₑₓₑₚ和Mₑₙₚ,我们使用了四个Tesla V100-SXM2-32GB GPU。预训练Mₑₓₑₚ约需7天时间,预训练Mₑₙₚ约需2天时间。对于微调和推断,我们使用了一个单独的Tesla V100-SXM2-32GB GPU。我们使用了Huggingface的库来实现模型并存储数据集。为了实现持续学习方法,我们使用了Avalanche。

3 实验

3.1 RQ1:在zero-shot中泛化成ID和OOD数据效果如何?

在本实验中,我们评估了模型在zero-shot设置下在ID和OOD测试数据集下处理两个下游任务的性能。我们没有对进行实验,因为该模型在未经过调优前无法生成代码,因此无法在zero-shot设置下运行。实验的目的有两个。首先旨在验证我们研究的实验设置。如果我们观察到在ID和OOD数据集上获得的评估指标有显著差异,这将表明我们的OOD场景是合理且良好构建的。其次,ID和OOD测试数据之间存在显著差距意味着像这样的PLM仍需要使用强大的迁移学习或持续学习技术,才能在不忘记历史数据的情况下泛化到新数据。

API调用预测。表2报告了ID和OOD测试数据上的EM@1、EM@5和EM@10。结果显示,该模型在ID数据上表现良好,EM@1几乎达到了73%。然而,在OOD数据上测试时,性能显著下降。考虑到更多API调用候选项时,性能下降不那么严重,但仍然是一个显著问题。此外,我们观察到性能下降的变化在不同的OOD数据集中有所不同。例如,该模型在安全领域()上表现比在Android()或Web()等领域更好,这些领域可能包含更多特定领域的API调用。

表2:在zero-shot中使用的API调用预测结果

API使用预测。表3报告了ID和OOD测试数据上的BLEU分数、EM和CodeBLEU分数。结果显示,与ID数据相比,模型在OOD数据上表现不佳,所有评估指标都显著下降。此外,我们注意到EM和CodeBLEU指标的变化与API调用预测任务中的EM@k指标类似。Android和Web领域的性能下降最为严重,而安全领域的下降最为轻微。

表3:在zero-shot中使用的API使用预测结果

我们的结果表明,模型未经过调优)在ID数据上表现强劲,但无法泛化到OOD数据。我们的发现也验证了我们的OOD数据集作为一个在连续环境中测试模型适应新数据能力的有效性。

3.2 RQ2:使用经典的迁移学习,模型是否遗忘了过去的数据?

这一部分我们评估经典迁移学习在持续学习场景中的表现,即在先前工作中使用的微调方式。我们依次在OOD数据集, ..., 上对模型进行微调。我们称这种方法为“朴素微调”,这是连续学习文献中常用的术语,因为它不利用机制来解决灾难性遗忘问题。我们报告了API调用预测的EM@1和API使用预测的EM结果。图4展示了模型在OOD测试集上的EM@1和EM指标在微调步骤中的变化。每个热图的列表示模型在特定测试集上性能的变化,每行表示一个新的增量微调步骤。为了量化灾难性遗忘,我们在表4中报告了EM@1和EM指标的遗忘指标。

图4:朴素微调方法的结果

表4:朴素微调基线的遗忘指标

我们的结果和观察表明,忘记历史数据的问题对于我们研究的两个模型都是一个重要问题,而对于模型来说,这个问题更为严重。即使在低数量的微调步骤中,灾难性遗忘也已经非常显著。通过考虑更多的微调步骤,我们可以预期这个问题会更加严重。我们得出结论,经典的迁移学习在先前工作中最常用的微调方法并没有强大和稳健到能够使模型在适应新数据的同时保留对历史数据的知识。

转述:邹英龙

0 阅读:0

互联不一般哥

简介:感谢大家的关注