音效素材网提供各类素材,打造精品素材网站!

站内导航 站长工具 投稿中心 手机访问

音效素材

Pytorch中Softmax和LogSoftmax的使用详解
日期:2021-09-08 14:38:47   来源:脚本之家

一、函数解释

1.Softmax函数常用的用法是指定参数dim就可以:

(1)dim=0:对每一列的所有元素进行softmax运算,并使得每一列所有元素和为1。

(2)dim=1:对每一行的所有元素进行softmax运算,并使得每一行所有元素和为1。

class Softmax(Module):
    r"""Applies the Softmax function to an n-dimensional input Tensor
    rescaling them so that the elements of the n-dimensional output Tensor
    lie in the range [0,1] and sum to 1.
    Softmax is defined as:
    .. math::
        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input
    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]
    Arguments:
        dim (int): A dimension along which Softmax will be computed (so every slice
            along dim will sum to 1).
    .. note::
        This module doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use `LogSoftmax` instead (it's faster and has better numerical properties).
    Examples::
        >>> m = nn.Softmax(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """
    __constants__ = ['dim']
 
    def __init__(self, dim=None):
        super(Softmax, self).__init__()
        self.dim = dim
 
    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None
 
    def forward(self, input):
        return F.softmax(input, self.dim, _stacklevel=5)
 
    def extra_repr(self):
        return 'dim={dim}'.format(dim=self.dim)

2.LogSoftmax其实就是对softmax的结果进行log,即Log(Softmax(x))

class LogSoftmax(Module):
    r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
    input Tensor. The LogSoftmax formulation can be simplified as:
    .. math::
        \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input
    Arguments:
        dim (int): A dimension along which LogSoftmax will be computed.
    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [-inf, 0)
    Examples::
        >>> m = nn.LogSoftmax()
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """
    __constants__ = ['dim']
 
    def __init__(self, dim=None):
        super(LogSoftmax, self).__init__()
        self.dim = dim
 
    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None
 
    def forward(self, input):
        return F.log_softmax(input, self.dim, _stacklevel=5)

二、代码示例

输入代码

import torch
import torch.nn as nn
import numpy as np
 
batch_size = 4
class_num = 6
inputs = torch.randn(batch_size, class_num)
for i in range(batch_size):
    for j in range(class_num):
        inputs[i][j] = (i + 1) * (j + 1)
 
print("inputs:", inputs)

得到大小batch_size为4,类别数为6的向量(可以理解为经过最后一层得到)

tensor([[ 1., 2., 3., 4., 5., 6.],
[ 2., 4., 6., 8., 10., 12.],
[ 3., 6., 9., 12., 15., 18.],
[ 4., 8., 12., 16., 20., 24.]])

接着我们对该向量每一行进行Softmax

Softmax = nn.Softmax(dim=1)
probs = Softmax(inputs)
print("probs:\n", probs)

得到

tensor([[4.2698e-03, 1.1606e-02, 3.1550e-02, 8.5761e-02, 2.3312e-01, 6.3369e-01],
[3.9256e-05, 2.9006e-04, 2.1433e-03, 1.5837e-02, 1.1702e-01, 8.6467e-01],
[2.9067e-07, 5.8383e-06, 1.1727e-04, 2.3553e-03, 4.7308e-02, 9.5021e-01],
[2.0234e-09, 1.1047e-07, 6.0317e-06, 3.2932e-04, 1.7980e-02, 9.8168e-01]])

此外,我们对该向量每一行进行LogSoftmax

LogSoftmax = nn.LogSoftmax(dim=1)
log_probs = LogSoftmax(inputs)
print("log_probs:\n", log_probs)

得到

tensor([[-5.4562e+00, -4.4562e+00, -3.4562e+00, -2.4562e+00, -1.4562e+00, -4.5619e-01],
[-1.0145e+01, -8.1454e+00, -6.1454e+00, -4.1454e+00, -2.1454e+00, -1.4541e-01],
[-1.5051e+01, -1.2051e+01, -9.0511e+00, -6.0511e+00, -3.0511e+00, -5.1069e-02],
[-2.0018e+01, -1.6018e+01, -1.2018e+01, -8.0185e+00, -4.0185e+00, -1.8485e-02]])

验证每一行元素和是否为1

# probs_sum in dim=1
probs_sum = [0 for i in range(batch_size)]
 
for i in range(batch_size):
    for j in range(class_num):
        probs_sum[i] += probs[i][j]
    print(i, "row probs sum:", probs_sum[i])

得到每一行的和,看到确实为1

0 row probs sum: tensor(1.)
1 row probs sum: tensor(1.0000)
2 row probs sum: tensor(1.)
3 row probs sum: tensor(1.)

验证LogSoftmax是对Softmax的结果进行Log

# to numpy
np_probs = probs.data.numpy()
print("numpy probs:\n", np_probs)
 
# np.log()
log_np_probs = np.log(np_probs)
print("log numpy probs:\n", log_np_probs)

得到

numpy probs:
[[4.26977826e-03 1.16064614e-02 3.15496325e-02 8.57607946e-02 2.33122006e-01 6.33691311e-01]
[3.92559559e-05 2.90064461e-04 2.14330270e-03 1.58369839e-02 1.17020354e-01 8.64669979e-01]
[2.90672347e-07 5.83831024e-06 1.17265590e-04 2.35534250e-03 4.73083146e-02 9.50212955e-01]
[2.02340233e-09 1.10474026e-07 6.03167746e-06 3.29318427e-04 1.79801770e-02 9.81684387e-01]]
log numpy probs:
[[-5.4561934e+00 -4.4561934e+00 -3.4561934e+00 -2.4561932e+00 -1.4561933e+00 -4.5619333e-01]
[-1.0145408e+01 -8.1454077e+00 -6.1454072e+00 -4.1454072e+00 -2.1454074e+00 -1.4540738e-01]
[-1.5051069e+01 -1.2051069e+01 -9.0510693e+00 -6.0510693e+00 -3.0510693e+00 -5.1069155e-02]
[-2.0018486e+01 -1.6018486e+01 -1.2018485e+01 -8.0184851e+00 -4.0184855e+00 -1.8485421e-02]]

验证完毕

三、整体代码

import torch
import torch.nn as nn
import numpy as np
 
batch_size = 4
class_num = 6
inputs = torch.randn(batch_size, class_num)
for i in range(batch_size):
    for j in range(class_num):
        inputs[i][j] = (i + 1) * (j + 1)
 
print("inputs:", inputs)
Softmax = nn.Softmax(dim=1)
probs = Softmax(inputs)
print("probs:\n", probs)
 
LogSoftmax = nn.LogSoftmax(dim=1)
log_probs = LogSoftmax(inputs)
print("log_probs:\n", log_probs)
 
# probs_sum in dim=1
probs_sum = [0 for i in range(batch_size)]
 
for i in range(batch_size):
    for j in range(class_num):
        probs_sum[i] += probs[i][j]
    print(i, "row probs sum:", probs_sum[i])
 
# to numpy
np_probs = probs.data.numpy()
print("numpy probs:\n", np_probs)
 
# np.log()
log_np_probs = np.log(np_probs)
print("log numpy probs:\n", log_np_probs)

基于pytorch softmax,logsoftmax 表达

import torch
import numpy as np
input = torch.autograd.Variable(torch.rand(1, 3))

print(input)
print('softmax={}'.format(torch.nn.functional.softmax(input, dim=1)))
print('logsoftmax={}'.format(np.log(torch.nn.functional.softmax(input, dim=1))))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持。

    您感兴趣的教程

    在docker中安装mysql详解

    本篇文章主要介绍了在docker中安装mysql详解,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编...

    详解 安装 docker mysql

    win10中文输入法仅在桌面显示怎么办?

    win10中文输入法仅在桌面显示怎么办?

    win10系统使用搜狗,QQ输入法只有在显示桌面的时候才出来,在使用其他程序输入框里面却只能输入字母数字,win10中...

    win10 中文输入法

    一分钟掌握linux系统目录结构

    这篇文章主要介绍了linux系统目录结构,通过结构图和多张表格了解linux系统目录结构,感兴趣的小伙伴们可以参考一...

    结构 目录 系统 linux

    PHP程序员玩转Linux系列 Linux和Windows安装

    这篇文章主要为大家详细介绍了PHP程序员玩转Linux系列文章,Linux和Windows安装nginx教程,具有一定的参考价值,感兴趣...

    玩转 程序员 安装 系列 PHP

    win10怎么安装杜比音效Doby V4.1 win10安装杜

    第四代杜比®家庭影院®技术包含了一整套协同工作的技术,让PC 发出清晰的环绕声同时第四代杜比家庭影院技术...

    win10杜比音效

    纯CSS实现iOS风格打开关闭选择框功能

    这篇文章主要介绍了纯CSS实现iOS风格打开关闭选择框,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作...

    css ios c

    Win7如何给C盘扩容 Win7系统电脑C盘扩容的办法

    Win7如何给C盘扩容 Win7系统电脑C盘扩容的

    Win7给电脑C盘扩容的办法大家知道吗?当系统分区C盘空间不足时,就需要给它扩容了,如果不管,C盘没有足够的空间...

    Win7 C盘 扩容

    百度推广竞品词的投放策略

    SEM是基于关键词搜索的营销活动。作为推广人员,我们所做的工作,就是打理成千上万的关键词,关注它们的质量度...

    百度推广 竞品词

    Visual Studio Code(vscode) git的使用教程

    这篇文章主要介绍了详解Visual Studio Code(vscode) git的使用,小编觉得挺不错的,现在分享给大家,也给大家做个参考。...

    教程 Studio Visual Code git

    七牛云储存创始人分享七牛的创立故事与

    这篇文章主要介绍了七牛云储存创始人分享七牛的创立故事与对Go语言的应用,七牛选用Go语言这门新兴的编程语言进行...

    七牛 Go语言

    Win10预览版Mobile 10547即将发布 9月19日上午

    微软副总裁Gabriel Aul的Twitter透露了 Win10 Mobile预览版10536即将发布,他表示该版本已进入内部慢速版阶段,发布时间目...

    Win10 预览版

    HTML标签meta总结,HTML5 head meta 属性整理

    移动前端开发中添加一些webkit专属的HTML5头部标签,帮助浏览器更好解析HTML代码,更好地将移动web前端页面表现出来...

    移动端html5模拟长按事件的实现方法

    这篇文章主要介绍了移动端html5模拟长按事件的实现方法的相关资料,小编觉得挺不错的,现在分享给大家,也给大家...

    移动端 html5 长按

    HTML常用meta大全(推荐)

    这篇文章主要介绍了HTML常用meta大全(推荐),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参...

    cdr怎么把图片转换成位图? cdr图片转换为位图的教程

    cdr怎么把图片转换成位图? cdr图片转换为

    cdr怎么把图片转换成位图?cdr中插入的图片想要转换成位图,该怎么转换呢?下面我们就来看看cdr图片转换为位图的...

    cdr 图片 位图

    win10系统怎么录屏?win10系统自带录屏详细教程

    win10系统怎么录屏?win10系统自带录屏详细

    当我们是使用win10系统的时候,想要录制电脑上的画面,这时候有人会想到下个第三方软件,其实可以用电脑上的自带...

    win10 系统自带录屏 详细教程

    + 更多教程 +
    ASP编程JSP编程PHP编程.NET编程python编程