Skip to content

Commit

Permalink
Add DirichletFullyConnectedActor to Soft Actor-Critic
Browse files Browse the repository at this point in the history
Summary: This can be used for problems where the action vector must sum to 1

Reviewed By: kittipatv

Differential Revision: D15206348

fbshipit-source-id: 665fbed893d8c52d451a12d3bb2e73b2638b7963
  • Loading branch information
econti authored and facebook-github-bot committed May 7, 2019
1 parent 4ca325d commit 2356fac
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
13 changes: 13 additions & 0 deletions test/onnx/test_pytorch_onnx_caffe2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,6 +1606,19 @@ def forward(self, input):
x = torch.empty(BATCH_SIZE, 10, 10).uniform_(4, 9)
self.run_model_test(Log2Model(), train=False, input=x, batch_size=BATCH_SIZE)

def test_sample_dirichlet(self):
    """Export ``torch._sample_dirichlet`` and run the exported model via ``c2``.

    The op draws random samples, so the result cannot be compared
    value-for-value; only the output shape is checked against the input
    shape.
    """
    class DirichletModel(torch.nn.Module):
        def forward(self, input):
            return torch._sample_dirichlet(input)

    # Dirichlet concentration parameters must be strictly positive.  A
    # standard-normal draw contains negative entries, so sample the input
    # from a positive range instead to exercise the op on valid parameters.
    x = torch.empty(2, 3, 4).uniform_(0.5, 2.0)
    model = DirichletModel()
    onnxir, _ = do_export(model, x)
    onnx_model = onnx.ModelProto.FromString(onnxir)
    prepared = c2.prepare(onnx_model)
    caffe2_out = prepared.run(inputs=[x.cpu().numpy()])
    self.assertEqual(caffe2_out[0].shape, x.shape)

def test_prim_shape(self):
x = torch.randn(4, 5, requires_grad=True)
@torch.jit.script
Expand Down
7 changes: 7 additions & 0 deletions torch/onnx/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,13 @@ def cumsum(g, input, dim):
return g.op("ATen", input, operator_s="cumsum", dim_i=dim)


def _sample_dirichlet(g, self, generator):
if not generator.node().mustBeNone():
return _unimplemented('_sample_dirichlet',
'We are not able to export generator')
return g.op("ATen", self, operator_s="_sample_dirichlet")


def t(g, self):
    """Symbolic for ``t``: emit a ``Transpose`` op with ``perm_i=(1, 0)``.

    Args:
        g: graph-like object whose ``op`` method creates the output node.
        self: the input value to transpose.

    Returns:
        The value produced by ``g.op`` for the ``Transpose`` node.
    """
    # Swap the two axes.
    swapped_axes = (1, 0)
    return g.op("Transpose", self, perm_i=swapped_axes)

Expand Down

0 comments on commit 2356fac

Please sign in to comment.