当我们深入探索深度学习的奇妙世界时,PyTorch作为一个强大且易用的框架,提供了丰富的功能来帮助我们高效地进行模型训练和数据处理。其中,DataLoader是PyTorch中一个非常核心且🔧实用的组件,它负责在模型训练过程中加载和处理数据。
通过灵活配置DataLoader的各种参数,我们可以优化数据加载速度,调整数据批次大小,甚至实现自定义的数据处理和抽样策略。
在这篇文章中,小编将详细解析DataLoader的每个参数,通过具体的示例代码展示它们的使用场景和效果,帮助你更深入地理解和使用PyTorch进行深度学习模型的开发。
一、 DataLoader的参数说明
- dataset (必需): 用于加载数据的数据集,通常是torch.utils.data.Dataset的子类实例。
- batch_size (可选): 每个批次的数据样本数。默认值为1。
- shuffle (可选): 是否在每个周期开始时打乱数据。默认为False。
- sampler (可选): 定义从数据集中抽取样本的策略。如果指定,则忽略shuffle参数。
- batch_sampler (可选): 与sampler类似,但一次返回一个批次的索引。不能与batch_size、shuffle和sampler同时使用。
- num_workers (可选): 用于数据加载的子进程数量。默认为0,意味着数据将在主进程中加载。
- collate_fn (可选): 如何将多个数据样本整合成一个批次。通常不需要指定。
- drop_last (可选): 如果数据集大小不能被批次大小整除,是否丢弃最后一个不完整的批次。默认为False。
二、 DataLoader的dataset参数(必需)
在实例化PyTorch的DataLoader类时,dataset参数是必需的,它指定了要从【哪个数据集对象】里面加载数据。该对象必须是torch.utils.data.Dataset的子类实例。
示例代码:
from torch.utils.data import Dataset, DataLoader class MyDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] # 创建自定义数据集实例 my_data = [1, 2, 3, 4, 5, 6, 7] my_dataset = MyDataset(my_data) # 使用DataLoader加载自定义数据集my_dataset dataloader = DataLoader(dataset=my_dataset)
三、 DataLoader的batch_size参数 (可选)
batch_size 参数指定了每个批次的数据样本数。默认值为1。
示例代码:
# 将批次大小设置为3,这意味着每个批次将包含3个数据样本。 dataloader = DataLoader(dataset=my_dataset, batch_size=3) for data in dataloader: print(data)
运行结果:
四、 DataLoader的shuffle参数 (可选)
shuffle参数指定是否在每个周期开始时打乱数据。默认为False。如果设置为True,则在每个周期开始时,数据将被随机打乱顺序。
示例代码:
# shuffle默认为False dataloader = DataLoader(dataset=my_dataset, batch_size=3) print("当shuffle=False时,运行结果如下:") print("*" * 30) for data in dataloader: print(data) print("*" * 30) dataloader = DataLoader(dataset=my_dataset, batch_size=3, shuffle=True) print("当shuffle=True时,运行结果如下:") print("*" * 30) for data in dataloader: print(data) print("*" * 30)
运行结果:
五、 DataLoader的drop_last参数 (可选)
drop_last参数决定了在数据批次划分时是否丢弃最后一个不完整的批次。当数据集的大小不能被批次大小整除时,最后一个批次的大小可能会小于指定的批次大小。drop_last参数用于控制是否保留这个不完整的批次。
使用场景:
当数据集大小不能被批次大小整除时,如果最后一个批次的大小较小,可能会导致模型训练时的不稳定。通过将drop_last设置为True,可以确保每个批次的大小都相同,从而避免这种情况。
在某些情况下,丢弃最后一个批次可能不会对整体训练效果产生太大影响,但可以减少计算资源的浪费。例如,当数据集非常大时,最后一个不完整的批次可能只包含很少的数据样本,对于整体训练过程的贡献较小。
示例代码:
# drop_last默认为False dataloader = DataLoader(dataset=my_dataset, batch_size=3) print("当drop_last=False时,运行结果如下:") print("*" * 30) for data in dataloader: print(data) print("*" * 30) dataloader = DataLoader(dataset=my_dataset, batch_size=3, drop_last=True) print("当drop_last=True时,运行结果如下:") print("*" * 30) for data in dataloader: print(data) print("*" * 30)
运行结果:可以看到,当drop_last=True时,最后一个批次的数据 tensor([7]) 被舍弃了。
六、DataLoader的sampler参数 (可选)
sampler参数定义从数据集中抽取样本的策略。如果指定了sampler,则忽略shuffle参数。它可以是任何实现了__iter__()方法的对象,通常会使用torch.utils.data.Sampler的子类。
示例代码:
from torch.utils.data import SubsetRandomSampler # 创建一个随机抽样器,只选择索引为偶数的样本 【索引从0开始~】 sampler = SubsetRandomSampler(indices=[i for i in range(0, len(my_dataset), 2)]) dataloader = DataLoader(dataset=my_dataset, sampler=sampler) for data in dataloader: print(data)
运行结果:
七、 DataLoader的batch_sampler参数 (可选)
batch_sampler参数与sampler类似,但它返回的是一批次的索引,而不是单个样本的索引。不能与batch_size、shuffle和sampler同时使用。
示例代码:
from torch.utils.data import BatchSampler from torch.utils.data import SubsetRandomSampler # 创建一个随机抽样器,只选择索引为偶数的样本 【索引从0开始~】 sampler = SubsetRandomSampler(indices=[i for i in range(0, len(my_dataset), 2)]) # 创建一个批量抽样器,每个批次包含2个样本 batch_sampler = BatchSampler(sampler, batch_size=2, drop_last=True) dataloader = DataLoader(dataset=my_dataset, batch_sampler=batch_sampler) for data in dataloader: print(data)
运行结果:
八、 DataLoader的num_workers参数 (可选)
num_workers参数指定用于数据加载的子进程数量。默认为0,表示数据将在主进程中加载。增加num_workers的值可以加快数据的加载速度,但也会增加内存消耗。
示例代码:
dataloader = DataLoader(dataset=my_dataset, num_workers=4)
代码解释:在这个示例中,我们将子进程数量设置为4,这意味着将使用4个子进程并行加载数据,以加快数据加载速度。
九、DataLoader的collate_fn参数 (可选)
collate_fn参数指定如何将多个数据样本整合成一个批次,通常不需要指定。如果需要自定义批次数据的整合方式,可以提供一个可调用的函数。该函数接受一个样本【列表】作为输入,返回一个批次的数据。
示例代码:
import torch class MyDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] # 创建自定义数据集实例 my_data = [1, 2, 3, 4, 5, 6, 7] my_dataset = MyDataset(my_data) def my_collate_fn(batch): print(type(batch)) # 将batch中的每个样本转换为pytorch的tensor并都加上10 return [torch.tensor(data) + 10 for data in batch] dataloader = DataLoader(dataset=my_dataset, batch_size=2, collate_fn=my_collate_fn) for data in dataloader: print(data)
运行结果:
参考链接:PyTorch入门必学:DataLoader(数据迭代器)参数解析与用法合集_python dataloader-CSDN博客 https://blog.csdn.net/qq_41813454/article/details/134903615