当前位置: 首页 > news >正文

【torch.utils.data.sampler】采样器的解析和使用

文章目录

  • torch.utils.data.sampler
    • 内置的Sampler
      • 基类 Sampler
      • 顺序采样 SequentialSampler
      • 随机采样 RandomSampler
      • 子集随机采样 SubsetRandomSampler
      • 加权随机采样 WeightedRandomSampler
      • 批采样 BatchSampler

torch.utils.data.sampler

内置的Sampler

基类 Sampler

sampler 采样器,是一个迭代器。PyTorch提供了多种采样器,用户也可以自定义采样器。所有sampler都是承 torch.utils.data.sampler.Sampler这个抽象类。

class Sampler(object):
    r"""Base class for all Samplers.
    """
    def __init__(self, data_source):
        pass
    def __iter__(self):
        raise NotImplementedError

顺序采样 SequentialSampler

  • 功能
    • 顺序地对元素进行采样,总是以相同的顺序。
  • 参数
    • data_source(Dataset): 采样的数据集

初始化方法仅仅需要一个Dataset类对象作为参数。对于__len__()只负责返回数据源包含的数据个数;iter()方法负责返回一个可迭代对象,这个可迭代对象是由range产生的顺序数值序列,也就是说迭代是按照顺序进行的。

class SequentialSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source
    def __iter__(self):
        return iter(range(len(self.data_source)))
    def __len__(self):
        return len(self.data_source)
  • 例子
# 定义数据和对应的采样器
data = list([17, 22, 3, 41, 8])
seq_sampler = sampler.SequentialSampler(data_source=data)
# 迭代获取采样器生成的索引
for index in seq_sampler:
    print("index: {}, data: {}".format(str(index), str(data[index])))

得到下面的输出,说明Sequential Sampler产生的索引是顺序索引:

index: 0, data: 17
index: 1, data: 22
index: 2, data: 3
index: 3, data: 41
index: 4, data: 8

随机采样 RandomSampler

  • 功能
    • 随机抽取元素。如果没有替换,则从打乱的数据集中采样。 如果有替换,则用户可以指定:attr:num_samples
  • 参数
    • data_source (Dataset): 采样的数据集
    • replacement (bool): 如果为 True抽取的样本是有放回的。默认是False
    • num_samples (int): 抽取样本的数量,默认是len(dataset)。当replacementTrue的时应该被被实例化
class RandomSampler(Sampler):
    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        # 这个参数控制的应该为是否重复采样
        self.replacement = replacement
        self._num_samples = num_samples
    def num_samples(self):
        # dataset size might change at runtime
        # 初始化时不传入num_samples的时候使用数据源的长度
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples
    # 返回数据集长度
    def __len__(self):
        return self.num_samples
		# 索引生成
		def __iter__(self):
		    n = len(self.data_source)
		    if self.replacement:
		        # 生成的随机数是可能重复的
		        return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
		    # 生成的随机数是不重复的
		    return iter(torch.randperm(n).tolist())

randint()函数生成的随机数学列是可能包含重复数值的,而randperm()函数生成的随机数序列是绝对不包含重复数值的

  • 例子
'''不使用replacement,生成的随机索引不重复'''
ran_sampler = sampler.RandomSampler(data_source=data)
# 得到下面输出
index: 0, data: 17
index: 2, data: 3
index: 3, data: 41
index: 4, data: 8
index: 1, data: 22

'''使用replacement,生成的随机索引有重复'''
ran_sampler = sampler.RandomSampler(data_source=data, replacement=True)
# 得到下面的输出
index: 0, data: 17
index: 4, data: 8
index: 3, data: 41
index: 4, data: 8
index: 2, data: 3

子集随机采样 SubsetRandomSampler

  • 功能
    • 从给定的索引列表中随机抽取元素,不进行替换。
  • 参数
    • indices (sequence): 索引列表
class SubsetRandomSampler(Sampler):
    def __init__(self, indices):
        # 数据集的切片,比如划分训练集和测试集
        self.indices = indices
    def __iter__(self):
        # 以元组形式返回不重复打乱后的“数据”
        return (self.indices[i] for i in torch.randperm(len(self.indices)))
    def __len__(self):
        return len(self.indices)

_iter__()返回的并不是随机数序列,而是通过随机数序列作为indices的索引,进而返回打乱的数据本身。需要注意的仍然是采样是不重复的,也是通过randperm()函数实现的。

  • 例子

下面将data划分为train和val两个部分

sub_sampler_train = sampler.SubsetRandomSampler(indices=data[0:2])
sub_sampler_val = sampler.SubsetRandomSampler(indices=data[2:])
# 下面是train输出
index: 17
index: 22
*************
# 下面是val输出
index: 8
index: 41
index: 3

加权随机采样 WeightedRandomSampler

  • 功能
    • 按照给定的概率权重weights, 对元素进行采样
  • 参数
    • weights权重序列
    • num_samples采样数
    • replacement 抽取的样本是否有放回
class WeightedRandomSampler(Sampler):
    def __init__(self, weights, num_samples, replacement=True):
         # ...省略类型检查
         # weights用于确定生成索引的权重
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        # 用于控制是否对数据进行有放回采样
        self.replacement = replacement
    def __iter__(self):
        # 按照加权返回随机索引值
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

__iter__()方法返回的数值为随机数序列,只不过生成的随机数序列是按照weights指定的权重确定的

  • 例子
# 位置[0]的权重为0,位置[1]的权重为10,其余位置权重均为1.1
weights = torch.Tensor([0, 10, 1.1, 1.1, 1.1, 1.1, 1.1])
wei_sampler = sampler.WeightedRandomSampler(weights=weights, num_samples=6, replacement=True)
# 下面是输出:
index: 1
index: 2
index: 3
index: 4
index: 1
index: 1

从输出可以看出,位置[1]由于权重较大,被采样的次数较多,位置[0]由于权重为0所以没有被采样到,其余位置权重低所以都仅仅被采样一次。

批采样 BatchSampler

  • 功能
    • 包装另一个采样器以生成一个小批量索引。
  • 参数
    • sampler对应前面介绍的XxxSampler类实例
    • batch_size 批量大小
    • drop_last为“True”时,如果采样得到的数据个数小于batch_size则抛弃本个batch的数据
class BatchSampler(Sampler):
    def __init__(self, sampler, batch_size, drop_last):# ...省略类型检查
        # 定义使用何种采样器Sampler
        self.sampler = sampler
        self.batch_size = batch_size
        # 是否在采样个数小于batch_size时剔除本次采样
        self.drop_last = drop_last
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            # 如果采样个数和batch_size相等则本次采样完成
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # for结束后在不需要剔除不足batch_size的采样个数时返回当前batch        
        if len(batch) > 0 and not self.drop_last:
            yield batch
    def __len__(self):
        # 在不进行剔除时,数据的长度就是采样器索引的长度
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
  • 例子

下面的例子中batch sampler采用的采样器为顺序采样器:

seq_sampler = sampler.SequentialSampler(data_source=data)
batch_sampler = sampler.BatchSampler(seq_sampler, 3, False)
# 下面是输出
batch: [0, 1, 2]
batch: [3, 4]

相关文章:

  • Springboot魅力乡村管理系统srb4s计算机毕业设计-课程设计-期末作业-毕设程序代做
  • 【星球】【slam】 研讨会(5)VINS:Mono+Fusion 重点提炼
  • 机器学习笔记之受限玻尔兹曼机(三)推断任务
  • 【ASE+python学习】-批量识别石墨烯团簇结构中的吡啶氮,并删除与其相连的氢
  • 【算法】排序——冒泡排序
  • 【数据结构】二分搜索树
  • MySQL 中的 sql_mode 选项以及配置
  • mysql数据库
  • JSP | 基于Servlet和JSP改造oa项目
  • 2022SDNU-ACM结训赛题解
  • JavaWeb_第5章_会话技术_Cookie+Session
  • 新手入门SLAM必备资料
  • python -- PyQt5(designer)中文详细教程(四)事件和信号
  • 【大数据入门核心技术-Hive】MySQL5.7安装
  • FISCO BCOS(二十五)———多机部署
  • 讲点登录业务
  • Python实现基于用户的协同过滤推荐算法构建电影推荐系统
  • 跟着实例学Go语言(二)
  • 树的递归算法与非递归(迭代)的转化重点理解代码(上篇)
  • OpenCV3图像处理笔记