您现在的位置是:首页 >其他 >NTM中attr的用法网站首页其他

NTM中attr的用法

phac123 2024-06-26 14:23:49
简介NTM中attr的用法

代码1

@attrs
class CopyTaskParams(object):
    name = attrib(default="copy-task")
    controller_size = attrib(default=100, convert=int)
    controller_layers = attrib(default=1,convert=int)
    num_heads = attrib(default=1, convert=int)
    sequence_width = attrib(default=8, convert=int)
    sequence_min_len = attrib(default=1,convert=int)
    sequence_max_len = attrib(default=20, convert=int)
    memory_n = attrib(default=128, convert=int)
    memory_m = attrib(default=20, convert=int)
    num_batches = attrib(default=50000, convert=int)
    batch_size = attrib(default=1, convert=int)
    rmsprop_lr = attrib(default=1e-4, convert=float)
    rmsprop_momentum = attrib(default=0.9, convert=float)
    rmsprop_alpha = attrib(default=0.95, convert=float)

解析

这段代码是使用Python的attrs库定义一个class,名为CopyTaskParams。这个class的对象包含了一系列的参数,用于控制复制任务模型的行为。下面解释一下每个参数的含义:

  • name: str,复制任务的名称,默认值为"copy-task"。
  • controller_size: int,控制器的大小,默认值为100。
  • controller_layers: int,控制器的层数,默认值为1。
  • num_heads: int,注意力机制中头的数量,默认值为1。
  • sequence_width: int,序列的宽度,默认值为8。
  • sequence_min_len: int,序列的最小长度,默认值为1。
  • sequence_max_len: int,序列的最大长度,默认值为20。
  • memory_n: int,记忆矩阵的行数,默认值为128。
  • memory_m: int,记忆矩阵的列数,默认值为20。
  • num_batches: int,训练时的批次数,默认值为50000。
  • batch_size: int,每个批次的大小,默认值为1。
  • rmsprop_lr: float,RMSProp中的学习率,默认值为1e-4。
  • rmsprop_momentum: float,RMSProp中的动量值,默认值为0.9。
  • rmsprop_alpha: float,RMSProp中的alpha值,默认值为0.95。

这个类是使用 Python 库 attrs 中的装饰器 attrib 定义的,该装饰器使用方法类似于 Python 标准库中的 property 装饰器。attrib 装饰器帮助我们自动生成实例变量,init 方法以及默认参数值等等,并提供验证和转换等功能。
通过使用 attrib 装饰器,我们可以在属性中设置 default 参数,来指定属性的初始值。convert 参数则指定该属性的类型转换方法。例如,convert=int 在将其赋为整数值之前,会尝试将其转换为整数类型。
另外,值得注意的是,在类中未定义 strrepr 方法时,它们将使用 attrs 帮助我们自动生成,以便在实例被打印时很好地显示。
总之,通过使用 attrs,有助于减少模板代码的编写和维护。同时,它还提供了很多其他有用的功能,如比较实例,填写数据缺失值等等。

代码2

import attr

@attr.s
class Point:
    x = attr.ib(default=0)
    y = attr.ib(default=0)
p1 = Point(x=1, y=2)
p2 = attr.evolve(p1, x=3)
p1 = attr.evolve(p1, x=3)

attr.evolve 是 attrs 库中的一个函数,其作用是创建一个原始对象的副本,并替换其中的一些属性值。它的函数签名如下:

attr.evolve(inst, **changes)

其中 inst 是需要进行修改的原始对象实例,changes 是一个字典,用于指定需要修改的属性和对应的新值。函数返回值是一个新生成的对象实例。
具体来说,changes 的键是需要进行修改的属性名,值是对应的新值。例如,假设有一个 Point 类,它用来表示二维平面上的点坐标。

  • 现在我们创建了一个名为 p1 的 Point 对象实例。
  • 如果我们需要修改这个对象实例中的 x 属性,可以使用 evolve 方法。例如,要把它的 x 属性值从 1 改为 3。
  • 这样,p2 对象实例的 x 属性值就被修改为了 3,而 y 属性的值保持原来不变。
  • 需要注意的是,原始对象实例 p1 的值并没有发生改变,evolve 方法并不会修改原始对象,而是生成一个新的对象实例。如果希望将原始对象实例也进行修改,要对p1操作。
风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。