博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Tf中的SGDOptimizer学习【转载】
阅读量:6611 次
发布时间:2019-06-24

本文共 3531 字,大约阅读时间需要 11 分钟。

转自:

1.tf.train.GradientDescentOptimizer

其中有函数:

1.1apply_gradients

apply_gradients(    grads_and_vars,    global_step=None,    name=None)

 

Apply gradients to variables.

This is the second part of minimize(). It returns an Operation that applies gradients.

将梯度应用到变量上。它是minimize函数的第二部分。

1.2compute_gradients

compute_gradients(    loss,    var_list=None,    gate_gradients=GATE_OP,    aggregation_method=None,    colocate_gradients_with_ops=False,    grad_loss=None)

 Compute gradients of loss for the variables in var_list.

This is the first part of minimize(). It returns a list of (gradient, variable) pairs where "gradient" is the gradient for "variable". Note that "gradient" can be a Tensor, an IndexedSlices, or None if there is no gradient for the given variable.

计算var-list的梯度,它是minimize函数的第一部分,返回的是一个list,对应每个变量都有梯度。准备使用apply_gradient函数更新。

下面重点来了: 

参数:

  • loss: A Tensor containing the value to minimize or a callable taking no arguments which returns the value to minimize. When eager execution is enabled it must be a callable.
  • var_list: Optional list or tuple of  to update to minimize loss. Defaults to the list of variables collected in the graph under the key GraphKeys.TRAINABLE_VARIABLES.

 loss就是损失函数,没啥了。

 这个第二个参数变量列表通常是不传入的,那么计算谁的梯度呢?上面说,默认的参数列表是计算图中的 GraphKeys.TRAINABLE_VARIABLES.

 去看这个的API发现:

 tf.GraphKeys

 The following standard keys are defined:

找到TRAINABLE_VARIABLES是:

  • TRAINABLE_VARIABLES: the subset of Variable objects that will be trained by an optimizer. See for more details.

然后再去看:

tf.trainable_variables

tf.trainable_variables(scope=None)

 

Returns all variables created with trainable=True.

When passed trainable=True, the Variable() constructor automatically adds new variables to the graph collectionGraphKeys.TRAINABLE_VARIABLES.

This convenience function returns the contents of that collection.

Returns:

A list of Variable objects.

然后再去看一下tf.Variable函数:

tf.Variable

__init__(    initial_value=None,    trainable=True,    collections=None,    validate_shape=True,    caching_device=None,    name=None,    variable_def=None,    dtype=None,    expected_shape=None,    import_scope=None,    constraint=None,    use_resource=None,    synchronization=tf.VariableSynchronization.AUTO,    aggregation=tf.VariableAggregation.NONE)

 

并且:

  • trainable: If True, the default, also adds the variable to the graph collection GraphKeys.TRAINABLE_VARIABLES. This collection is used as the default list of variables to use by the Optimizer classes.

 默认为真,并且加入可训练变量集中,所以:

在word2vec实现中,

with tf.device('/cpu:0'):      # Look up embeddings for inputs.      with tf.name_scope('embeddings'):        embeddings = tf.Variable(            tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))        embed = tf.nn.embedding_lookup(embeddings, train_inputs)

 

定义的embeddings应该是可以更新的。怎么更新?:

with tf.name_scope('loss'):      loss = tf.reduce_mean(          tf.nn.nce_loss(              weights=nce_weights,              biases=nce_biases,              labels=train_labels,              inputs=embed,              num_sampled=num_sampled,              num_classes=vocabulary_size))    # Add the loss value as a scalar to summary.    tf.summary.scalar('loss', loss)    # Construct the SGD optimizer using a learning rate of 1.0.    with tf.name_scope('optimizer'):      optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)

 

使用SGD随机梯度下降,在minimize损失函数中,应该是会对所有的可训练变量求导,对的,没错一定是这样,所以nec_weights,nce_biases,embeddings都是可更新变量。

都是通过先计算损失函数,求导然后更新变量,在迭代数据计算损失函数,求导更新,

这样来更新的。

转载于:https://www.cnblogs.com/BlueBlueSea/p/10616314.html

你可能感兴趣的文章
C#学习笔记2
查看>>
LCLFramework架构必须要知道的知识
查看>>
[公益课程]Spring Boot 2.x 实战入门
查看>>
Centos7.5修改双系统启动顺序
查看>>
css 补漏
查看>>
Common ways to tell time
查看>>
C++ 无法从void 转换为 LRESULT
查看>>
[原]Unity3D深入浅出 - 光源组件(Light)
查看>>
数据库对象(视图,序列,索引,同义词)【weber出品必属精品】
查看>>
ubuntu下安装和配置java开发环境
查看>>
axios跨域问题
查看>>
测试用例-场景法
查看>>
SFDC_03(覆盖率)
查看>>
hdu 1443 Joseph 约瑟夫环
查看>>
CRM项目需求变更带来的麻烦
查看>>
JSONHelper 的摘要说明
查看>>
Swing 美化工具包
查看>>
UVa1366 Martian Mining
查看>>
指定子设备号创建字符设备
查看>>
debugs
查看>>