PyTorch入门必学:DataLoader(数据迭代器)参数解析与用法合集

当我们深入探索深度学习的奇妙世界时,PyTorch作为一个强大且易用的框架,提供了丰富的功能来帮助我们高效地进行模型训练和数据处理。其中,DataLoader是PyTorch中一个非常核心且🔧实用的组件,它负责在模型训练过程中加载和处理数据。

通过灵活配置DataLoader的各种参数,我们可以优化数据加载速度,调整数据批次大小,甚至实现自定义的数据处理和抽样策略。

在这篇文章中,小编将详细解析DataLoader的每个参数,通过具体的示例代码展示它们的使用场景和效果,帮助你更深入地理解和使用PyTorch进行深度学习模型的开发。

一、 DataLoader的参数说明

  1. dataset (必需):  用于加载数据的数据集,通常是torch.utils.data.Dataset的子类实例。
  2. batch_size (可选):  每个批次的数据样本数。默认值为1。
  3. shuffle (可选):  是否在每个周期开始时打乱数据。默认为False
  4. sampler (可选):  定义从数据集中抽取样本的策略。如果指定,则忽略shuffle参数。
  5. batch_sampler (可选): sampler类似,但一次返回一个批次的索引。不能与batch_size、shuffle和sampler同时使用。
  6. num_workers (可选):  用于数据加载的子进程数量。默认为0,意味着数据将在主进程中加载。
  7. collate_fn (可选):  如何将多个数据样本整合成一个批次。通常不需要指定。
  8. drop_last (可选):   如果数据集大小不能被批次大小整除,是否丢弃最后一个不完整的批次。默认为False

二、 DataLoader的dataset参数(必需)

在实例化PyTorch的DataLoader类时,dataset参数是必需的,它指定了要从【哪个数据集对象】里面加载数据。该对象必须是torch.utils.data.Dataset的子类实例。

示例代码:

from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

# 创建自定义数据集实例
my_data = [1, 2, 3, 4, 5, 6, 7]
my_dataset = MyDataset(my_data)

# 使用DataLoader加载自定义数据集my_dataset
dataloader = DataLoader(dataset=my_dataset)

三、 DataLoader的batch_size参数 (可选)

batch_size 参数指定了每个批次的数据样本数。默认值为1。

示例代码:

# 将批次大小设置为3,这意味着每个批次将包含3个数据样本。
dataloader = DataLoader(dataset=my_dataset, batch_size=3)

for data in dataloader:
	print(data)

运行结果:

四、 DataLoader的shuffle参数 (可选)

shuffle参数指定是否在每个周期开始时打乱数据。默认为False。如果设置为True,则在每个周期开始时,数据将被随机打乱顺序。

示例代码:

# shuffle默认为False
dataloader = DataLoader(dataset=my_dataset, batch_size=3)

print("当shuffle=False时,运行结果如下:")
print("*" * 30)
for data in dataloader:
    print(data)
print("*" * 30)

dataloader = DataLoader(dataset=my_dataset, batch_size=3, shuffle=True)

print("当shuffle=True时,运行结果如下:")
print("*" * 30)
for data in dataloader:
    print(data)
print("*" * 30)

运行结果:

五、 DataLoader的drop_last参数 (可选)

drop_last参数决定了在数据批次划分时是否丢弃最后一个不完整的批次。当数据集的大小不能被批次大小整除时,最后一个批次的大小可能会小于指定的批次大小。drop_last参数用于控制是否保留这个不完整的批次。

使用场景:

当数据集大小不能被批次大小整除时,如果最后一个批次的大小较小,可能会导致模型训练时的不稳定。通过将drop_last设置为True,可以确保每个批次的大小都相同,从而避免这种情况。
在某些情况下,丢弃最后一个批次可能不会对整体训练效果产生太大影响,但可以减少计算资源的浪费。例如,当数据集非常大时,最后一个不完整的批次可能只包含很少的数据样本,对于整体训练过程的贡献较小。

示例代码:

# drop_last默认为False
dataloader = DataLoader(dataset=my_dataset, batch_size=3)

print("当drop_last=False时,运行结果如下:")
print("*" * 30)
for data in dataloader:
    print(data)
print("*" * 30)

dataloader = DataLoader(dataset=my_dataset, batch_size=3, drop_last=True)

print("当drop_last=True时,运行结果如下:")
print("*" * 30)
for data in dataloader:
    print(data)
print("*" * 30)

运行结果:可以看到,当drop_last=True时,最后一个批次的数据 tensor([7]) 被舍弃了。

六、DataLoader的sampler参数 (可选)

sampler参数定义从数据集中抽取样本的策略。如果指定了sampler,则忽略shuffle参数。它可以是任何实现了__iter__()方法的对象,通常会使用torch.utils.data.Sampler的子类。

示例代码:

from torch.utils.data import SubsetRandomSampler

# 创建一个随机抽样器,只选择索引为偶数的样本 【索引从0开始~】
sampler = SubsetRandomSampler(indices=[i for i in range(0, len(my_dataset), 2)])
dataloader = DataLoader(dataset=my_dataset, sampler=sampler)

for data in dataloader:
    print(data)

运行结果:

七、 DataLoader的batch_sampler参数 (可选)

batch_sampler参数与sampler类似,但它返回的是一批次的索引,而不是单个样本的索引。不能与batch_sizeshufflesampler同时使用。

示例代码:

from torch.utils.data import BatchSampler
from torch.utils.data import SubsetRandomSampler

# 创建一个随机抽样器,只选择索引为偶数的样本 【索引从0开始~】
sampler = SubsetRandomSampler(indices=[i for i in range(0, len(my_dataset), 2)])

# 创建一个批量抽样器,每个批次包含2个样本
batch_sampler = BatchSampler(sampler, batch_size=2, drop_last=True)
dataloader = DataLoader(dataset=my_dataset, batch_sampler=batch_sampler)

for data in dataloader:
    print(data)

运行结果:

八、 DataLoader的num_workers参数 (可选)

num_workers参数指定用于数据加载的子进程数量。默认为0,表示数据将在主进程中加载。增加num_workers的值可以加快数据的加载速度,但也会增加内存消耗。

示例代码:

dataloader = DataLoader(dataset=my_dataset, num_workers=4)

代码解释:在这个示例中,我们将子进程数量设置为4,这意味着将使用4个子进程并行加载数据,以加快数据加载速度。

九、DataLoader的collate_fn参数 (可选)

collate_fn参数指定如何将多个数据样本整合成一个批次,通常不需要指定。如果需要自定义批次数据的整合方式,可以提供一个可调用的函数。该函数接受一个样本【列表】作为输入,返回一个批次的数据。

示例代码:

import torch
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# 创建自定义数据集实例
my_data = [1, 2, 3, 4, 5, 6, 7]
my_dataset = MyDataset(my_data)

def my_collate_fn(batch):
    print(type(batch))
    # 将batch中的每个样本转换为pytorch的tensor并都加上10
    return [torch.tensor(data) + 10 for data in batch]

dataloader = DataLoader(dataset=my_dataset, batch_size=2, collate_fn=my_collate_fn)

for data in dataloader:
    print(data)

运行结果:


参考链接:PyTorch入门必学:DataLoader(数据迭代器)参数解析与用法合集_python dataloader-CSDN博客 https://blog.csdn.net/qq_41813454/article/details/134903615

感谢你的阅读


欢迎评论交流


忽如一夜春风来,千树万树梨花开。

暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇