关于tf.distributions的那些事儿

xiaoxiao2025-05-15  28

引子

       在学习各类Machine Learning方法时,免不了要与“分布”打交道。我们有时候需要计算某个分布的熵,有时候需要计算两个分布之间的交叉熵或KL散度。当然,这可以通过使用Numpy中的numpy.random.normal之类的函数来实现,但是我们更希望能够按照TensorFlow计算图的形式来实现,这样的话,可以更好地利用TensorFlow的一些优势(如一次性计算,共享计算结果等)。

简介

       tf.distributions是TensorFlow提供的核心组件之一,用于实现一些常见的概率分布,并给出了一系列的辅助计算函数。首先,该组件中有Distribution基类、RegisterKL类、ReparameterizationType类。其中RegisterKL类是一个注册KL散度实现的装饰器,也即可以为某个分布添加KL散度的计算功能。此外,该组件还实现了以下分布:

Bernoulli Distribution;Beta Distribution;Categorical Distribution;Dirichlet Distribution;Dirichlet-Multinomial Distribution;Exponential Distribution;Gamma Distribution;Laplace Distribution;Multinomial Distribution;Normal Distribution;StudentT Distribution;Uniform Distribution.

       下面我们以Normal Distribution为例来进行介绍。

tf.distributions.Normal

       Normal类型定义在tensorflow/python/ops/distributions/normal.py文件中。其__init__函数定义如下:

__init__( loc, scale, validate_args=False, allow_nan_stats=True, name='Normal' )

       其中loc为高斯分布的均值 μ \mu μ,scale为标准差 σ \sigma σ。        在Normal类中,有如下properties:

allow_nan_statsbatch_shapedtypeevent_shapelocnameparametersreparameterization_typescalevalidate_args        关于这些性质的解释就不赘述了。下面列出Normal类中给出的一些方法(列出来只是为了能够一目了然):batch_shape_tensor (name=‘batch_shape_tensor’)cdf (value, name=‘cdf’)copy (**override_parameters_kwargs)covariance (name=‘covariance’)cross_entropy (other, name=‘cross_entropy’)entropy (name=‘entropy’)event_shape_tensor (name=‘event_shape_tensor’)is_scalar_batch (name=‘is_scalar_batch’)is_scalar_event (name=‘is_scalar_event’)kl_divergence (other, name=‘kl_divergence’)log_cdf (value, name=‘log_cdf’)log_prob (value, name=‘log_prob’)log_survival_function (value, name=‘log_survival_function’)mean (name=‘mean’)mode (name=‘mode’)param_shapes (cls, sample_shape, name=‘DistributionParamShapes’)param_static_shapes (cls, sample_shape)prob (value, name=‘prob’)quantile (value, name=‘quantile’)sample (sample_shape=(), seed=None, name=‘sample’)stddev (name=‘stddev’)survival_function (value, name=‘survival_function’)variance (name=‘variance’)        其中有计算熵的entropy方法,计算交叉熵的cross_entropy方法,计算KL散度(相对熵)的kl_divergence方法,这些方法为我们提供了极大的便利。

尾声

       本文对tf.distributions进行了极简的介绍,大家如果对此有兴趣的话可以直接在TensorFlow官网查看,具体见:tf.distributions。        大家周五快乐~

转载请注明原文地址: https://www.6miu.com/read-5030132.html

最新回复(0)