Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] Jax support #8

Closed
wants to merge 4 commits into from
Closed

[Draft] Jax support #8

wants to merge 4 commits into from

Conversation

joaogui1
Copy link

@joaogui1 joaogui1 commented Nov 3, 2019

So, right now when running the test_docstring_example test I get an error, here's the stack trace

  File "cvxpylayers/jax/test_cvxpylayer.py", line 65, in test_docstring_example
    gradA = jax.grad(summed_solution)(A_tf, b_tf)
  File "/home/john/anaconda3/envs/cvx-test/lib/python3.7/site-packages/jax/api.py", line 341, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/home/john/anaconda3/envs/cvx-test/lib/python3.7/site-packages/jax/api.py", line 387, in value_and_grad_f
    ans, vjp_py = vjp(f_partial, *dyn_args)
  File "/home/john/anaconda3/envs/cvx-test/lib/python3.7/site-packages/jax/api.py", line 1002, in vjp
    out_primal, out_vjp = ad.vjp(jaxtree_fun, primals_flat)
  File "/home/john/anaconda3/envs/cvx-test/lib/python3.7/site-packages/jax/interpreters/ad.py", line 105, in vjp
    out_primal, pval, jaxpr, consts = linearize(traceable, *primals)
  File "/home/john/anaconda3/envs/cvx-test/lib/python3.7/site-packages/jax/interpreters/ad.py", line 94, in linearize
    jaxpr, out_pval, consts = pe.trace_to_jaxpr(jvpfun, in_pvals)
  File "/home/john/anaconda3/envs/cvx-test/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 400, in trace_to_jaxpr
    jaxpr, (out_pval, consts, env) = fun.call_wrapped(pvals)
  File "/home/john/anaconda3/envs/cvx-test/lib/python3.7/site-packages/jax/linear_util.py", line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "cvxpylayers/jax/test_cvxpylayer.py", line 63, in summed_solution
    solution = cvxpylayer(x, y)
  File "/home/john/Documents/Programming/cvxpylayers/cvxpylayers/jax/cvxpylayer.py", line 103, in __call__
    compute, compute_grad = self._compute_and_grad(parameters, solver_args)
  File "/home/john/Documents/Programming/cvxpylayers/cvxpylayers/jax/cvxpylayer.py", line 158, in _compute_and_grad
    A, b, c = self._problem_data_from_params(params)
  File "/home/john/Documents/Programming/cvxpylayers/cvxpylayers/jax/cvxpylayer.py", line 120, in _problem_data_from_params
    params)))
  File "/home/john/anaconda3/envs/cvx-test/lib/python3.7/site-packages/cvxpy/reductions/dcp2cone/cone_matrix_stuffing.py", line 94, in apply_parameters
    zero_offset=zero_offset)
  File "/home/john/anaconda3/envs/cvx-test/lib/python3.7/site-packages/cvxpy/cvxcore/python/canonInterface.py", line 57, in get_parameter_vector
    param_vec[col:col + size] = value
TypeError: __array__() takes 1 positional argument but 2 were given

@gsp-27, can you give me some help with this?

@joaogui1
Copy link
Author

joaogui1 commented Nov 5, 2019

Hey @bamos, do you have any idea what this error could be?

@bamos
Copy link
Collaborator

bamos commented Nov 5, 2019

Hey @bamos, do you have any idea what this error could be?

Hmm, can you check what the parameter values you're passing down into cvxpy are? They should be numpy arrays that match the shape of the parameters.

One further way of debugging this would be to set up the same problem in PyTorch/TF and compare what's happening at this point in the code between those and here

@sbarratt
Copy link
Collaborator

sbarratt commented Nov 5, 2019

I think you need to do something along these lines: jax-ml/jax#1142 (comment)

@sbarratt
Copy link
Collaborator

sbarratt commented Nov 5, 2019

Also, we will need to disable JIT around our function: https://github.com/google/jax/blob/master/jax/api.py#L157

@joaogui1 joaogui1 closed this Dec 14, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants