Skip to content

Commit

Permalink
Merge pull request #252 from ranqiu92/mt_with_external_memory
Browse files Browse the repository at this point in the history
update README.md and train.py
  • Loading branch information
lcy-seso authored Sep 16, 2017
2 parents c754dbe + f456031 commit 2c0e478
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
9 changes: 7 additions & 2 deletions mt_with_external_memory/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class ExternalMemory(object):
name,
mem_slot_size,
boot_layer,
initial_weight,
readonly=False,
enable_interpolation=True):
""" Initialization.
Expand All @@ -154,6 +155,8 @@ class ExternalMemory(object):
sequence layer has sequence length indicating the number
of memory slots, and size as memory slot size.
:type boot_layer: LayerOutput
:param initial_weight: Initializer for addressing weights.
:type initial_weight: LayerOutput
:param readonly: If true, the memory is read-only, and write function cannot
be called. Default is false.
:type readonly: bool
Expand Down Expand Up @@ -205,7 +208,7 @@ class ExternalMemory(object):
- `_content_addressing`: 通过基于内容的寻址,计算得到读写操作的寻址强度。
- `_interpolation`: 通过插值寻址(当前寻址强度和上一时间步寻址强度的线性加权),更新当前寻址强度。
- `_get_addressing_weight`: 调用上述两个寻址操作,获得对存储导员的读写操作的最终寻址强度
- `_get_addressing_weight`: 调用上述两个寻址操作,获得对存储单元的读写操作的最终寻址强度
对外接口包含:
Expand All @@ -214,6 +217,7 @@ class ExternalMemory(object):
- 输入参数 `name`: 外部记忆单元名,不同实例的相同命名将共享同一外部记忆单元。
- 输入参数 `mem_slot_size`: 单个记忆槽(向量)的维度。
- 输入参数 `boot_layer`: 用于内存槽初始化的层。需为序列类型,序列长度表明记忆槽的数量。
- 输入参数 `initial_weight`: 用于初始化寻址强度。
- 输入参数 `readonly`: 是否打开只读模式(例如打开只读模式,该实例可用于注意力机制)。打开只读模式,`write` 方法不可被调用。
- 输入参数 `enable_interpolation`: 是否允许插值寻址(例如当用于注意力机制时,需要关闭插值寻址)。
- `write`: 写操作。
Expand All @@ -230,7 +234,6 @@ class ExternalMemory(object):
self.external_memory = paddle.layer.memory(
name=self.name,
size=self.mem_slot_size,
is_seq=True,
boot_layer=boot_layer)
```
- `ExternalMemory`类的寻址逻辑通过 `_content_addressing` 和 `_interpolation` 两个私有方法实现。读和写操作通过 `read` 和 `write` 两个函数实现,包括上述的寻址操作。并且读和写的寻址独立进行,不同于 \[[2](#参考文献)\] 中的二者共享同一个寻址强度,目的是为了使得该类更通用。
Expand Down Expand Up @@ -349,6 +352,7 @@ def memory_enhanced_seq2seq(encoder_input, decoder_input, decoder_target,
name="unbounded_memory",
mem_slot_size=size * 2,
boot_layer=unbounded_memory_init,
initial_weight=unbounded_memory_weight_init,
readonly=True,
enable_interpolation=False)
```
Expand All @@ -359,6 +363,7 @@ def memory_enhanced_seq2seq(encoder_input, decoder_input, decoder_target,
name="bounded_memory",
mem_slot_size=size,
boot_layer=bounded_memory_init,
initial_weight=bounded_memory_weight_init,
readonly=False,
enable_interpolation=True)
```
Expand Down
2 changes: 1 addition & 1 deletion mt_with_external_memory/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def event_handler(event):
sys.stdout.flush()
if isinstance(event, paddle.event.EndPass):
result = trainer.test(reader=test_batch_reader, feeding=feeding)
print "Pass: %d, TestCost: %f, %s" % (event.pass_id, event.cost,
print "Pass: %d, TestCost: %f, %s" % (event.pass_id, result.cost,
result.metrics)
with gzip.open("checkpoints/params.pass-%d.tar.gz" % event.pass_id,
'w') as f:
Expand Down

0 comments on commit 2c0e478

Please sign in to comment.