Skip to content

Commit 2356fac

Browse files
econtifacebook-github-bot
authored andcommitted
Add DirichletFullyConnectedActor to Soft Actor-Critic
Summary: This can be used for problems where the action vector must sum to 1 Reviewed By: kittipatv Differential Revision: D15206348 fbshipit-source-id: 665fbed893d8c52d451a12d3bb2e73b2638b7963
1 parent 4ca325d commit 2356fac

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

test/onnx/test_pytorch_onnx_caffe2.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,6 +1606,19 @@ def forward(self, input):
16061606
x = torch.empty(BATCH_SIZE, 10, 10).uniform_(4, 9)
16071607
self.run_model_test(Log2Model(), train=False, input=x, batch_size=BATCH_SIZE)
16081608

1609+
def test_sample_dirichlet(self):
1610+
class DirichletModel(torch.nn.Module):
1611+
def forward(self, input):
1612+
return torch._sample_dirichlet(input)
1613+
1614+
x = torch.randn(2, 3, 4, requires_grad=False)
1615+
model = DirichletModel()
1616+
onnxir, _ = do_export(model, x)
1617+
onnx_model = onnx.ModelProto.FromString(onnxir)
1618+
prepared = c2.prepare(onnx_model)
1619+
caffe2_out = prepared.run(inputs=[x.cpu().numpy()])
1620+
self.assertEqual(caffe2_out[0].shape, x.shape)
1621+
16091622
def test_prim_shape(self):
16101623
x = torch.randn(4, 5, requires_grad=True)
16111624
@torch.jit.script

torch/onnx/symbolic.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,13 @@ def cumsum(g, input, dim):
423423
return g.op("ATen", input, operator_s="cumsum", dim_i=dim)
424424

425425

426+
def _sample_dirichlet(g, self, generator):
427+
if not generator.node().mustBeNone():
428+
return _unimplemented('_sample_dirichlet',
429+
'We are not able to export generator')
430+
return g.op("ATen", self, operator_s="_sample_dirichlet")
431+
432+
426433
def t(g, self):
427434
return g.op("Transpose", self, perm_i=(1, 0))
428435

0 commit comments

Comments
 (0)