您现在的位置是:首页 >其他 >使用PyTorch处理AG_NEWS新闻分类数据集网站首页其他

使用PyTorch处理AG_NEWS新闻分类数据集

墨D芯 2025-07-23 00:01:02
简介使用PyTorch处理AG_NEWS新闻分类数据集

如何使用PyTorch处理AG_NEWS新闻分类数据集,主要包括数据加载、文本分词、词汇表构建以及预处理流水线的定义。

1. 数据加载与查看

from torchtext.datasets import AG_NEWS
train_iter = AG_NEWS(root='../datasets', split='train')
print("连续三个next(train_iter)得到的结果:")
print(next(train_iter))
print(next(train_iter))
print(next(train_iter))
  • 功能:加载AG_NEWS训练集,并打印前三个样本。
  • 输出示例:每个样本为元组 (标签, 文本),如 (3, "Wall St. Bears Claw Back Into the Black...")
  • 注意:AG_NEWS的标签为 1~4,分别对应类别:World, Sports, Business, Sci/Tec

2. 分词器与词汇表构建

tokenizer = get_tokenizer('basic_english')  # 基础英文分词器(小写+按空格分割)
train_iter = AG_NEWS(root='../datasets', split='train')  # 重新加载迭代器

def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)  # 生成分词后的列表

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])  # 未登录词映射到<unk>
  • 分词器:将文本转换为小写并按空格分割(例如 "Hello World" → ["hello", "world"])。
  • 词汇表:基于训练数据的所有分词结果构建,添加 <unk> 处理未见过的词。
  • 注意:重新加载 train_iter 是为了避免之前打印样本时的数据消耗。

3. 词汇表测试

print("vocab('Mary Had a Little Lamb'.lower().split())")
print(vocab(['mary', 'had', 'a', 'little', 'lamb']))  # 手动分词结果

print("vocab(tokenizer('Mary Had a Little Lamb'.lower()))")
print(vocab(tokenizer('mary had a little lamb')))     # 分词器处理后的结果
  • 输出:返回每个词在词汇表中的索引列表,例如 [1032, 0, 3, 543, 789]
  • 作用:验证分词和词汇表是否正常工作。

4. 预处理流水线

def text_pipeline(x):
    return vocab(tokenizer(x))  # 文本→分词→索引列表

def label_pipeline(c):
    return int(c) - 1  # 标签1~4 → 0~3(适应模型输出)
  • 文本流水线:将原始文本转换为模型可接受的索引序列。
  • 标签流水线:将标签调整为从0开始(PyTorch模型通常要求类别标签为 0~N-1)。

5. 预处理测试

print("text_pipeline('Mary Had a Little Lamb'.lower())")
print(text_pipeline('mary had a little lamb'))  # 输出索引列表

print("label_pipeline('4')")
print(label_pipeline('4'))  # 输出3(对应Sci/Tec)
  • 验证结果:确保文本转换和标签调整符合预期。

潜在问题与改进

  1. 迭代器重置:多次使用 train_iter 时需确保数据重新加载(代码中已正确处理)。
  2. 标签类型:假设数据集中标签是字符串(如 '3'),需转换为整数。若实际标签为整数,需调整 label_pipeline
  3. 分词优化basic_english 分词器较简单,可替换为更复杂的模型(如BERT分词器)。
  4. 填充与截断:文本序列需统一长度(代码未处理,实际训练时需添加)。
风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。