Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consider only enabled collidable points in contact forces computation for Rigid, RelaxedRigid and Soft contact models #274

Merged
merged 14 commits into from
Nov 14, 2024

Conversation

xela-95
Copy link
Member

@xela-95 xela-95 commented Oct 24, 2024

In this PR I'm adding the possibility to consider only a subset of the available collidable points of the model in the computation of the contact forces.

I'm extending the behavior already introduced in #248 for ViscoElasticContacts to

  • RigidContacts
  • RelaxedRigidContacts
  • SoftContacts

The modifications in this PR have the intention of be sufficiently general to easily extend this also to SoftContacts.


📚 Documentation preview 📚: https://jaxsim--274.org.readthedocs.build//274/

@xela-95 xela-95 self-assigned this Oct 24, 2024
@xela-95 xela-95 changed the title Feature/extend contact enabled mask Consider only enabled collidable points in computation of contact forces for Rigid and RelaxedRigid contact models Oct 24, 2024
@xela-95 xela-95 changed the title Consider only enabled collidable points in computation of contact forces for Rigid and RelaxedRigid contact models Consider only enabled collidable points in contact forces computation for Rigid and RelaxedRigid contact models Oct 24, 2024
@xela-95
Copy link
Member Author

xela-95 commented Oct 24, 2024

@diegoferigo I have some questions about this work:

I tried to extend this work to SoftContacts but I'm getting errors due to the integration of the tangential deformation rate $\dot{m}$. I checked your implementation in ViscoElasticContacts, but in that case the integration is part of the contact model and only the tangential deformation related to the enabled collidable points is updated:

data_tf.state.extended |= {
"tangential_deformation": data_tf.state.extended["tangential_deformation"]
.at[indices_of_enabled_collidable_points]
.set(m_tf)
}

While if I do a similar thing with SoftContacts I get the following error:


File ~/repos/jaxsim/src/jaxsim/api/model.py:2013, in step(model, data, integrator, t0, dt, integrator_state, link_forces, joint_force_references, **kwargs)
   2010     τ_references = references.joint_force_references(model=model)
   2012 # Step the dynamics forward.
-> 2013 state_tf, integrator_state_tf = integrator.step(
   2014     x0=state_t0,
   2015     t0=t0,
   2016     dt=dt,
   2017     params=integrator_state_t0,
   2018     # Always inject the current (model, data) pair into the system dynamics
   2019     # considered by the integrator, and include the input variables represented
   2020     # by the pair (f_L, τ_references).
   2021     # Note that the wrapper of the system dynamics will override (state_x0, t0)
   2022     # inside the passed data even if it is not strictly needed. This logic is
   2023     # necessary to re-use the jit-compiled step function of compatible pytrees
   2024     # of model and data produced e.g. by parameterized applications.
   2025     **(
   2026         dict(
   2027             model=model,
   2028             data=data,
   2029             link_forces=f_L,
   2030             joint_force_references=τ_references,
   2031         )
   2032         | integrator_kwargs
   2033     ),
   2034 )
   2036 # Store the new state of the model.
   2037 data_tf = data.replace(state=state_tf)

File ~/repos/jaxsim/src/jaxsim/integrators/common.py:112, in Integrator.step(self, x0, t0, dt, params, **kwargs)
    109     integrator.params = params
    111 with integrator.mutable_context(mutability=Mutability.MUTABLE):
--> 112     xf, aux_dict = integrator(x0, t0, dt, **kwargs)
    114 return (
    115     xf,
    116     integrator.params
    117     | {Integrator.AfterInitKey: jnp.array(False).astype(bool)}
    118     | aux_dict,
    119 )

File ~/repos/jaxsim/src/jaxsim/integrators/common.py:288, in ExplicitRungeKutta.__call__(self, x0, t0, dt, **kwargs)
    281 def __call__(
    282     self, x0: State, t0: Time, dt: TimeStep, **kwargs
    283 ) -> tuple[NextState, dict[str, Any]]:
   (...)
    286     # Note that z has multiple batches only if b.T has more than one row,
    287     # e.g. in Butcher tableau of embedded schemes.
--> 288     z, aux_dict = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
    290     # The next state is the batch element located at the configured index of solution.
    291     next_state = jax.tree.map(lambda l: l[self.row_index_of_solution], z)

File ~/repos/jaxsim/src/jaxsim/integrators/common.py:426, in ExplicitRungeKutta._compute_next_state(self, x0, t0, dt, **kwargs)
    423     return carry, aux_dict
    425 # Compute the state derivatives kᵢ.
--> 426 K, aux_dict = jax.lax.scan(
    427     f=scan_body,
    428     init=carry0,
    429     xs=jnp.arange(c.size),
    430 )
    432 # Update the FSAL property for the next iteration.
    433 if self.has_fsal:

    [... skipping hidden 9 frame]

File ~/repos/jaxsim/src/jaxsim/integrators/common.py:420, in ExplicitRungeKutta._compute_next_state.<locals>.scan_body(carry, i)
    418 # Store the kᵢ derivative in K.
    419 op = lambda l_k, l_ki: l_k.at[i].set(l_ki)
--> 420 K = jax.tree.map(op, K, ki)
    422 carry = K
    423 return carry, aux_dict

File ~/mambaforge/envs/comododev/lib/python3.10/site-packages/jax/_src/tree.py:155, in map(f, tree, is_leaf, *rest)
    115 def map(f: Callable[..., Any],
    116         tree: Any,
    117         *rest: Any,
    118         is_leaf: Callable[[Any], bool] | None = None) -> Any:
    119   """Maps a multi-input function over pytree args to produce a new pytree.
    120
    121   Args:
   (...)
    153     - :func:`jax.tree.reduce`
    154   """
--> 155   return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)

    [... skipping hidden 2 frame]

File ~/repos/jaxsim/src/jaxsim/integrators/common.py:419, in ExplicitRungeKutta._compute_next_state.<locals>.scan_body.<locals>.<lambda>(l_k, l_ki)
    412 ki, aux_dict = jax.lax.cond(
    413     pred=jnp.logical_and(i == 0, self.has_fsal),
    414     true_fun=get_ẋ0_and_aux_dict,
    415     false_fun=compute_ki,
    416 )
    418 # Store the kᵢ derivative in K.
--> 419 op = lambda l_k, l_ki: l_k.at[i].set(l_ki)
    420 K = jax.tree.map(op, K, ki)
    422 carry = K

File ~/mambaforge/envs/comododev/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:788, in _IndexUpdateRef.set(self, values, indices_are_sorted, unique_indices, mode)
    779 def set(self, values, *, indices_are_sorted=False, unique_indices=False,
    780         mode=None):
    781   """Pure equivalent of ``x[idx] = y``.
    782
    783   Returns the value of ``x`` that would result from the NumPy-style
   (...)
    786   See :mod:`jax.ops` for details.
    787   """
--> 788   return scatter._scatter_update(self.array, self.index, values, lax.scatter,
    789                                  indices_are_sorted=indices_are_sorted,
    790                                  unique_indices=unique_indices, mode=mode)

File ~/mambaforge/envs/comododev/lib/python3.10/site-packages/jax/_src/ops/scatter.py:76, in _scatter_update(x, idx, y, scatter_op, indices_are_sorted, unique_indices, mode, normalize_indices)
     73 # XLA gathers and scatters are very similar in structure; the scatter logic
     74 # is more or less a transpose of the gather equivalent.
     75 treedef, static_idx, dynamic_idx = jnp._split_index_for_jit(idx, x.shape)
---> 76 return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
     77                      indices_are_sorted, unique_indices, mode,
     78                      normalize_indices)

File ~/mambaforge/envs/comododev/lib/python3.10/site-packages/jax/_src/ops/scatter.py:111, in _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, normalize_indices)
    108 x, y = promote_dtypes(x, y)
    110 # Broadcast `y` to the slice output shape.
--> 111 y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
    112 # Collapse any `None`/`jnp.newaxis` dimensions.
    113 y = jnp.squeeze(y, axis=indexer.newaxis_dims)

File ~/mambaforge/envs/comododev/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2882, in broadcast_to(array, shape)
   2848 def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array:
   2849   """Broadcast an array to a specified shape.
   2850
   2851   JAX implementation of :func:`numpy.broadcast_to`. JAX uses NumPy-style
   (...)
   2880   .. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html
   2881   """
-> 2882   return util._broadcast_to(array, shape)

File ~/mambaforge/envs/comododev/lib/python3.10/site-packages/jax/_src/numpy/util.py:406, in _broadcast_to(arr, shape)
    404 if nlead < 0 or not compatible:
    405   msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
--> 406   raise ValueError(msg.format(arr_shape, shape))
    407 return lax.broadcast_in_dim(arr, shape, tuple(range(nlead, len(shape))))

ValueError: Incompatible shapes for broadcasting: (16, 3) and requested shape (48, 3)

(where in this example, 16 is the number of enabled collidable points and 48 is their total number).

Have you any idea on how to address this?

@xela-95
Copy link
Member Author

xela-95 commented Oct 24, 2024

@diegoferigo do you think is useful to add some unit tests for this? If yes, do you suggest to update some of the tests already in place or to perform some other simulation tests with a reduced number of collidable points?

@flferretti
Copy link
Collaborator

My guess is that the integrator has been initialized passing a dynamics that is closed on a model with a different shape of contact_parameters.point, you should re-create the integrator or try with #252

@diegoferigo
Copy link
Member

I tried to extend this work to SoftContacts but I'm getting errors due to the integration of the tangential deformation rate $\dot{m}$. I checked your implementation in ViscoElasticContacts, but in that case the integration is part of the contact model and only the tangential deformation related to the enabled collidable points is updated:

data_tf.state.extended |= {
"tangential_deformation": data_tf.state.extended["tangential_deformation"]
.at[indices_of_enabled_collidable_points]
.set(m_tf)
}

While if I do a similar thing with SoftContacts I get the following error:
[...]

(where in this example, 16 is the number of enabled collidable points and 48 is their total number).

Have you any idea on how to address this?

I start with a remark. Contrarily to state-less contact models like RigidContacts and RelaxedRigidContacts, our soft-like models (like SoftContacts and ViscoElasticContacts) operate on an additional state $\mathbf{m}$ that represent the tangential deformation of the material corresponding to individual contact points.

Filtering either statically or dynamically the active contact points of stateless models is reasonable1. This is because when a active point becomes inactive, or another contact point is considered instead, there's no state to handle. This is no longer true in soft-like models.

As a consequence, I believe that dynamic filtering of active contact points of soft-like models is something that we cannot achieve (cc @flferretti). However, static filtering can be supported, and this PR will only implement this type of filtering. This being said, I think that you need to initialize the tangential deformation variables stored in JaxSimData with the right shape (and I mean, with as many rows as enabled collidable points).

do you think is useful to add some unit tests for this? If yes, do you suggest to update some of the tests already in place or to perform some other simulation tests with a reduced number of collidable points?

In the meantime, you can maybe filter 3 out of 4 of the bottom corners of the falling box simulation with different contact models. We can discuss this better later.

Footnotes

  1. Possibly while maintaining the same number of considered contacts to avoid jit recompilations.

@xela-95
Copy link
Member Author

xela-95 commented Oct 24, 2024

I start with a remark. Contrarily to state-less contact models like RigidContacts and RelaxedRigidContacts, our soft-like models (like SoftContacts and ViscoElasticContacts) operate on an additional state m that represent the tangential deformation of the material corresponding to individual contact points.

Filtering either statically or dynamically the active contact points of stateless models is reasonable1. This is because when a active point becomes inactive, or another contact point is considered instead, there's no state to handle. This is no longer true in soft-like models.

As a consequence, I believe that dynamic filtering of active contact points of soft-like models is something that we cannot achieve (cc @flferretti). However, static filtering can be supported, and this PR will only implement this type of filtering. This being said, I think that you need to initialize the tangential deformation variables stored in JaxSimData with the right shape (and I mean, with as many rows as enabled collidable points).

Thanks a lot for the clear remarks!

@xela-95 xela-95 force-pushed the feature/extend-contact-enabled-mask branch from c61bdc5 to 1c517cb Compare October 29, 2024 11:56
@xela-95
Copy link
Member Author

xela-95 commented Oct 29, 2024

In the end by just updating tangential deformations just for the enabled collidable points, without changing the shape of the array of this state the static filtering of collidable points also worked for Soft Contacts.

@xela-95 xela-95 marked this pull request as ready for review October 29, 2024 12:00
@xela-95 xela-95 requested a review from flferretti as a code owner October 29, 2024 12:00
Copy link
Collaborator

@flferretti flferretti left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot @xela-95! I just left a couple minor comments

src/jaxsim/api/contact.py Show resolved Hide resolved
src/jaxsim/rbda/contacts/relaxed_rigid.py Outdated Show resolved Hide resolved
@xela-95 xela-95 changed the title Consider only enabled collidable points in contact forces computation for Rigid and RelaxedRigid contact models Consider only enabled collidable points in contact forces computation for Rigid, RelaxedRigid and Soft contact models Oct 29, 2024
@xela-95 xela-95 force-pushed the feature/extend-contact-enabled-mask branch from 1c517cb to d14636b Compare November 5, 2024 13:13
@xela-95 xela-95 requested a review from diegoferigo November 5, 2024 13:14
Copy link
Member

@diegoferigo diegoferigo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @xela-95 for your work! I'm leaving a general comment to trigger a discussion, I'm not sure what's the best approach.

This PR implements the filtering of the collidable points at the contact model level. Furthermore, to streamline this approach for integration purpose, it also updates jaxsim.api.contact.collidable_point_dynamics to return the contact forces of only the enabled collidable points. This last change creates a disparity with all the other APIs belonging to jaxsim.api.contact, that will keep returning data corresponding to all the collidable points.

I'm not sure here how to proceed. There are a few ways:

  1. Compute the contact forces of all collidable points and slice the output of collidable_point_dynamics at the caller level, as done for other contact-related quantities. This will however defeat any performance boost that would result from the filtering.
  2. Return a full matrix of contact forces and aux_data, having zero content for disabled collidable points.
  3. Adjust all the other APIs to return quantities corresponding to only the enabled collidable points.

I'm not sure which one is the best at first sight. For the time being, to reduce the work, I'd tend towards 2 and possibly give 3 a try in the future. But I'd like to discuss first to know your opinion.

@xela-95
Copy link
Member Author

xela-95 commented Nov 5, 2024

Thanks @xela-95 for your work! I'm leaving a general comment to trigger a discussion, I'm not sure what's the best approach.

This PR implements the filtering of the collidable points at the contact model level. Furthermore, to streamline this approach for integration purpose, it also updates jaxsim.api.contact.collidable_point_dynamics to return the contact forces of only the enabled collidable points. This last change creates a disparity with all the other APIs belonging to jaxsim.api.contact, that will keep returning data corresponding to all the collidable points.

I'm not sure here how to proceed. There are a few ways:

  1. Compute the contact forces of all collidable points and slice the output of collidable_point_dynamics at the caller level, as done for other contact-related quantities. This will however defeat any performance boost that would result from the filtering.
  2. Return a full matrix of contact forces and aux_data, having zero content for disabled collidable points.
  3. Adjust all the other APIs to return quantities corresponding to only the enabled collidable points.

I'm not sure which one is the best at first sight. For the time being, to reduce the work, I'd tend towards 2 and possibly give 3 a try in the future. But I'd like to discuss first to know your opinion

Thanks a lot for the feedback @diegoferigo ! Yes you're right this PR could led to a non obvious disparity in the jaxsim.api.contact API usage from the user perspective. I also agree that for now we could always return matrices from the contact APIs having the shape of the number of total collidable points, with zero elements apart the ones related to the enabled collidable points.

Ideally form the user perspective I think it will be more useful (and less of a hassle) to get always quantities related only to enabled collidable points, but this behavior could have impacts that we should test better, maybe in a separate PR.

I will update the current PR with the solution 2) for the time being. i.e. updating all contact APIs accordingly to return non zero data only for the enabled collidable points.

@diegoferigo
Copy link
Member

Ideally form the user perspective I think it will be more useful (and less of a hassle) to get always quantities related only to enabled collidable points, but this behavior could have impacts that we should test better, maybe in a separate PR.

My concerns on the extended filter is that the filtering in the future can be useful to dynamically select a subset of points (as opposed to now that is hardcoded). Being the mask a static attribute, if it changes automatically, there will be jit recompilation. We might work around it with jax.lax.dynamic_slice1, but I've never use it that much to be confident that this solution is what we need here. This should be properly evaluated before started working on option 3. Regardless, a dynamic slice would still require to compute the the larger matrix, and this would bring no/just little performance benefits. I agree to start with 2, and start thinking whether if 3 is really worth.

Footnotes

  1. https://stackoverflow.com/a/76628452

@flferretti
Copy link
Collaborator

flferretti commented Nov 5, 2024

For what regards:

  1. Adjust all the other APIs to return quantities corresponding to only the enabled collidable points.

The lines to be modified are:

api.contact.collidable_point_dynamics

api.contact.jacobian

api.contact.jacobian_derivative

  • parent_link_idxs = jnp.array(model.kin_dyn_parameters.contact_parameters.body)
    L_p_Ci = jnp.array(model.kin_dyn_parameters.contact_parameters.point)

api.contact.transforms

api.contact.in_contact

rbda.collidable_points.collidable_points_pos_vel

And then reset the modifications done inside rbda.contact.* as the previous modifications will naturally be propagated

It seems doable, WDYT?

@xela-95
Copy link
Member Author

xela-95 commented Nov 5, 2024

My concerns on the extended filter is that the filtering in the future can be useful to dynamically select a subset of points (as opposed to now that is hardcoded). Being the mask a static attribute, if it changes automatically, there will be jit recompilation. We might work around it with jax.lax.dynamic_slice1, but I've never use it that much to be confident that this solution is what we need here. This should be properly evaluated before started working on option 3. Regardless, a dynamic slice would still require to compute the the larger matrix, and this would bring no/just little performance benefits. I agree to start with 2, and start thinking whether if 3 is really worth.

Yes for sure before to start considering the idea of dynamic slicing we should investigate the possible performance outcomes deriving from the issues you're describing.

It seems doable, WDYT?

Thanks for the contribution @flferretti ! For me, this is doable with the static filtering we adopted in this PR. WRT the considerations expressed by @diegoferigo these modifications could produce a performance increase for now, that maybe should be changed again when/if we'll address dynamic filtering. I can give it a try since it should be doable quite quickly.

@xela-95
Copy link
Member Author

xela-95 commented Nov 13, 2024

For what regards:

  1. Adjust all the other APIs to return quantities corresponding to only the enabled collidable points.

The lines to be modified are:

api.contact.collidable_point_dynamics

api.contact.jacobian

api.contact.jacobian_derivative

  • parent_link_idxs = jnp.array(model.kin_dyn_parameters.contact_parameters.body)
    L_p_Ci = jnp.array(model.kin_dyn_parameters.contact_parameters.point)

api.contact.transforms

api.contact.in_contact

rbda.collidable_points.collidable_points_pos_vel

And then reset the modifications done inside rbda.contact.* as the previous modifications will naturally be propagated

It seems doable, WDYT?

Done, now the jaxsim.api.contact and jaxsim.rbda.collidable_points APIs return results for the enabled collidable points. I adapted the unit tests to deal with it correctly.

@xela-95 xela-95 force-pushed the feature/extend-contact-enabled-mask branch from f0d2726 to 593d9bd Compare November 13, 2024 16:33
@xela-95
Copy link
Member Author

xela-95 commented Nov 13, 2024

Rebased on main

@xela-95 xela-95 force-pushed the feature/extend-contact-enabled-mask branch from 593d9bd to c531bbd Compare November 14, 2024 08:02
@xela-95 xela-95 merged commit d5c06bb into ami-iit:main Nov 14, 2024
12 of 13 checks passed
@xela-95 xela-95 deleted the feature/extend-contact-enabled-mask branch November 14, 2024 08:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants