本文介绍: 因为trainer默认是以aotu_batch方式加载处理数据,因此本部分记录aotu_batch方式。另外本文记录trainer创建dataloader的基础过程。对于一些个性化加载处理、如长文档文本分类,如有必要,会另起一篇文章再进行记录

概述

transformers trainer训练、评估模型中,大致根据以下过程加载处理训练、评估数据集:

  1. 使用dataset.Dataset加载数据
  2. 使用Dataset.map自定义convert_examples_to_features函数处理Dataset中的每一行数
  3. 定义sampler,在迭代Dataloader过程中,本质迭代sampler。默认autobatch模式下,sampler在每次迭代过程中,会返回一个batch索引数值(indices),然后根据indices从Dataloader.dataset中取数据(fetch)。e.g. [self.dataset[index] for index in batch_indices]
  4. 第三步取到的数据,喂到collator_fn中,组装成tensor类型,并返回组装后的结果

因为trainer默认是以aotu_batch方式加载处理数据,因此本部分仅记录aotu_batch方式。另外本文仅记录trainer中创建dataloader的基础过程。对于一些个性化加载处理、如长文档文本分类,如有必要,会另起一篇文章再进行记录。

实例

# set up
from typing import List, Dict, Union

from datasets import Dataset
from transformers import default_data_collator
from transformers import BertTokenizer
from torch.utils.data import DataLoader, RandomSampler, BatchSampler, SequentialSampler

from config import CKP  # huggingface 中预训练模型下载本地地址

# emotion classification demo
x = [{"texts": "我爱中国。", "labels": 1}, {"texts": "今天天气真糟糕!", "labels": 0}] * 3

# 可以使用datasets.load_dataset函数,将样本数存储json格式,每一条样本占据一行
examples: Dataset = Dataset.from_list(x)
tokenizer: BertTokenizer = BertTokenizer.from_pretrained(CKP)

def convert_examples_to_features(exams: Dict[str, List[Union[str, int]]]):
    return tokenizer(exams["texts"], padding=True, max_length=20, truncation=True)

# map数中的batch=True并不影响最终结果,只是影响convert_examples_to_features签名|定义
dataset = examples.map(convert_examples_to_features, with_indices=False, with_rank=False, batched=True,
                       batch_size=1, remove_columns=["texts"])

# 验证sampler
sequence_sampler = SequentialSampler(dataset)
print(f"sequence sampler: {list(sequence_sampler)}")

random_sampler = RandomSampler(dataset)
print(f"random sampler: {list(random_sampler)}")

batch_sampler = BatchSampler(random_sampler, batch_size=2, drop_last=False)
print(f"batch sampler: {list(batch_sampler)}")

# 在convert_examples_to_features已经对input_ids进行了pad,所以使用default_data_collator
# 如果仅进行编码,即padding=False, 此处使用transformers.DataCollatorWithPadding
dataloader = DataLoader(dataset, batch_size=1, collate_fn=default_data_collator)

# add breakpoint in here, you will see
# step1. get next batch indices
# step2. fetch data according batch indices
# step3. collator data by collator_fn and return batch
for batch in dataloader:
    print(batch)

参考资料

datasets.Dataset.map方法学习笔记
transformers中的data_collator
【pytorch】Dataloader学习笔记

原文地址:https://blog.csdn.net/weixin_44815943/article/details/134672232

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任

如若转载,请注明出处:http://www.7code.cn/show_25302.html

如若内容造成侵权/违法违规/事实不符,请联系代码007邮箱suwngjj01@126.com进行投诉反馈,一经查实,立即删除

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注