在 PyTorch 中,可以使用 torch.distributions.Categorical
来基于给定的概率分布进行采样。
下面是一个示例:
import torch
import torch.distributions as dist
# 创建一个大小为 (1, n) 的一行张量表示概率分布
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
# 使用 Categorical 分布进行采样
m = dist.Categorical(probs)
samples = m.sample((10,)) # 生成 10 个样本
print(samples)
在上述示例中,我们首先创建了一个大小为 (1, n)
的一行张量 probs
表示概率分布。然后,我们使用 torch.distributions.Categorical
类来创建一个 Categorical
分布对象 m
。该分布由给定的概率分布 probs
定义。接下来,我们使用 sample()
方法从分布中生成 10
个样本,并将其存储在 samples
中。最后,我们打印输出了生成的样本。
请注意,sample()
方法返回的张量的形状由传递给它的参数决定。在上述示例中,我们用 (10,)
指定了要生成 10
个样本,所以返回的张量的形状为 (10,)
。如果没有指定参数,则默认生成单个样本。此外,Categorical
分布还提供了 log_prob()
方法,用于计算给定样本的对数概率。
如果需要基于给定的概率分布进行不重复采样,可以使用 torch.multinomial()
函数以及循环来实现。
下面是一个示例:
import torch
# 创建一个大小为 (1, n) 的一行张量表示概率分布
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
# 确定要生成的样本数量
num_samples = 3
# 初始化空列表用于存储已经出现过的样本索引
sampled_indices = []
# 循环采样直到得到足够数量的不同样本
while len(sampled_indices) < num_samples:
# 使用 multinomial 函数生成一个样本索引
index = torch.multinomial(probs, 1).item()
# 如果该索引在 sampled_indices 中没有出现过,则将其加入 sampled_indices 列表中
if index not in sampled_indices:
sampled_indices.append(index)
# 将采样结果存储在样本张量中
samples = probs[sampled_indices]
print(samples)
在上述示例中,我们首先创建了一个大小为 (1, n)
的一行张量 probs
表示概率分布。接下来,我们确定要生成的样本数量,并初始化一个空列表 sampled_indices
,用于存储已经出现过的样本索引。然后,我们使用 torch.multinomial()
函数生成一个样本索引,并判断该索引是否已经在 sampled_indices
中出现过。如果该索引没有出现过,则将其加入 sampled_indices
列表中,否则继续循环生成新的样本索引。最后,我们将采样结果存储在样本张量 samples
中,并打印输出。
请注意,上述代码只实现了一种简单的不重复抽样方法。如果需要进行更高效的不重复抽样,可以使用其他算法,如 Fisher-Yates shuffle
算法等。
可以使用PyTorch的torch.multinomial()
函数来进行不重复抽样。该函数接受一个概率分布张量和要抽取样本的数量作为输入,并返回一个整数张量,表示从概率分布中抽取的样本的索引。如果希望进行不重复抽样,可以在调用torch.multinomial()
函数时将参数replacement
设置为False
。例如:
import torch
# 创建概率分布张量
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
# 进行不重复抽样
samples = torch.multinomial(probs, num_samples=3, replacement=False)
print(samples)
输出结果将是一个长度为3的整数张量,表示从概率分布中抽取的三个不重复样本的索引。
原文地址:https://blog.csdn.net/AdamCY888/article/details/134762311
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.7code.cn/show_49244.html
如若内容造成侵权/违法违规/事实不符,请联系代码007邮箱:suwngjj01@126.com进行投诉反馈,一经查实,立即删除!