Skip to content

Commit

Permalink
Merge pull request #253 from ami-iit/uniform_usage_of_joint_force_ref…
Browse files Browse the repository at this point in the history
…erences

Uniform usage of `joint_forces` and `joint_force_references` argument names
  • Loading branch information
diegoferigo authored Oct 7, 2024
2 parents 5c9215e + 07b3ece commit 1bbf438
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 51 deletions.
4 changes: 2 additions & 2 deletions examples/PD_controller.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@
" data=data,\n",
" integrator=integrator,\n",
" integrator_state=integrator_state,\n",
" joint_forces=None,\n",
" joint_force_references=None,\n",
" link_forces=None,\n",
" )\n",
"\n",
Expand Down Expand Up @@ -276,7 +276,7 @@
" data=data,\n",
" integrator=integrator,\n",
" integrator_state=integrator_state,\n",
" joint_forces=control_torques,\n",
" joint_force_references=control_torques,\n",
" link_forces=None,\n",
" )\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/Parallel_computing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@
" data=data,\n",
" integrator=integrator,\n",
" integrator_state=integrator_state,\n",
" joint_forces=None,\n",
" joint_force_references=None,\n",
" link_forces=None,\n",
" )\n",
" x_t_i.append(data.base_position())\n",
Expand Down
8 changes: 4 additions & 4 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1907,8 +1907,8 @@ def step(
dt: jtp.FloatLike,
integrator: jaxsim.integrators.Integrator,
integrator_state: dict[str, Any] | None = None,
joint_forces: jtp.VectorLike | None = None,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
**kwargs,
) -> tuple[js.data.JaxSimModelData, dict[str, Any]]:
"""
Expand All @@ -1920,10 +1920,10 @@ def step(
dt: The time step to consider.
integrator: The integrator to use.
integrator_state: The state of the integrator.
joint_forces: The joint forces to consider.
link_forces:
The 6D forces to apply to the links expressed in the frame corresponding to
the velocity representation of `data`.
joint_force_references: The joint force references to consider.
kwargs: Additional kwargs to pass to the integrator.
Returns:
Expand Down Expand Up @@ -1953,7 +1953,7 @@ def step(
params=integrator_state_x0,
# Always inject the current (model, data) pair into the system dynamics
# considered by the integrator, and include the input variables represented
# by the pair (joint_forces, link_forces).
# by the pair (joint_force_references, link_forces).
# Note that the wrapper of the system dynamics will override (state_x0, t0)
# inside the passed data even if it is not strictly needed. This logic is
# necessary to re-use the jit-compiled step function of compatible pytrees
Expand All @@ -1962,7 +1962,7 @@ def step(
dict(
model=model,
data=data,
joint_forces=joint_forces,
joint_force_references=joint_force_references,
link_forces=link_forces,
)
| integrator_kwargs
Expand Down
48 changes: 29 additions & 19 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,19 @@ def system_velocity_dynamics(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
joint_forces: jtp.Vector | None = None,
link_forces: jtp.Vector | None = None,
joint_force_references: jtp.Vector | None = None,
) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]:
"""
Compute the dynamics of the system velocity.
Args:
model: The model to consider.
data: The data of the considered model.
joint_forces: The joint force references to apply.
link_forces:
The 6D forces to apply to the links expressed in the frame corresponding to
the velocity representation of `data`.
joint_force_references: The joint force references to apply.
Returns:
A tuple containing the derivative of the base 6D velocity in inertial-fixed
Expand All @@ -120,7 +120,7 @@ def system_velocity_dynamics(
references = js.references.JaxSimModelReferences.build(
model=model,
link_forces=O_f_L,
joint_force_references=joint_forces,
joint_force_references=joint_force_references,
data=data,
velocity_representation=data.velocity_representation,
)
Expand Down Expand Up @@ -192,7 +192,10 @@ def system_velocity_dynamics(
f_L_total = references.link_forces(model=model, data=data)

v̇_WB, = system_acceleration(
model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
model=model,
data=data,
joint_force_references=joint_force_references,
link_forces=f_L_total,
)

return v̇_WB, , aux_data
Expand All @@ -202,21 +205,22 @@ def system_acceleration(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
joint_forces: jtp.VectorLike | None = None,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
) -> tuple[jtp.Vector, jtp.Vector]:
"""
Compute the system acceleration in the active representation.
Args:
model: The model to consider.
data: The data of the considered model.
joint_forces: The joint forces to apply.
link_forces:
The 6D forces to apply to the links expressed in the same representation of data.
The 6D forces to apply to the links expressed in the same
velocity representation of data.
joint_force_references: The joint force references to apply.
Returns:
A tuple containing the base 6D acceleration in in the active representation
A tuple containing the base 6D acceleration in the active representation
and the joint accelerations.
"""

Expand All @@ -232,9 +236,9 @@ def system_acceleration(
).astype(float)

# Build joint torques if not provided.
τ = (
jnp.atleast_1d(joint_forces.squeeze())
if joint_forces is not None
τ_references = (
jnp.atleast_1d(joint_force_references.squeeze())
if joint_force_references is not None
else jnp.zeros_like(data.joint_positions())
).astype(float)

Expand All @@ -243,15 +247,16 @@ def system_acceleration(
# ====================

# TODO: enforce joint limits
τ_position_limit = jnp.zeros_like(τ).astype(float)
τ_position_limit = jnp.zeros_like(τ_references).astype(float)

# ====================
# Joint friction model
# ====================

τ_friction = jnp.zeros_like(τ).astype(float)
τ_friction = jnp.zeros_like(τ_references).astype(float)

if model.dofs() > 0:

# Static and viscous joint friction parameters
kc = jnp.array(
model.kin_dyn_parameters.joint_parameters.friction_static
Expand All @@ -271,22 +276,27 @@ def system_acceleration(
# ========================

# Compute the total joint forces.
τ_total = τ + τ_friction + τ_position_limit
τ_total = τ_references + τ_friction + τ_position_limit

# Store the link forces in a references object.
references = js.references.JaxSimModelReferences.build(
model=model,
data=data,
velocity_representation=data.velocity_representation,
joint_force_references=τ_total,
link_forces=f_L,
)

# Compute forward dynamics.
#
# - Joint accelerations: s̈ ∈ ℝⁿ
# - Base acceleration: v̇_WB ∈ ℝ⁶
#
# Note that ABA returns the base acceleration in the velocity representation
# stored in the `data` object.
v̇_WB, = js.model.forward_dynamics_aba(
model=model,
data=data,
joint_forces=references.joint_force_references(model=model),
joint_forces=τ_total,
link_forces=references.link_forces(model=model, data=data),
)

Expand Down Expand Up @@ -337,8 +347,8 @@ def system_dynamics(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
joint_forces: jtp.Vector | None = None,
link_forces: jtp.Vector | None = None,
joint_force_references: jtp.Vector | None = None,
baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
) -> tuple[ODEState, dict[str, Any]]:
"""
Expand All @@ -347,10 +357,10 @@ def system_dynamics(
Args:
model: The model to consider.
data: The data of the considered model.
joint_forces: The joint forces to apply.
link_forces:
The 6D forces to apply to the links expressed in the frame corresponding to
the velocity representation of `data`.
joint_force_references: The joint force references to apply.
baumgarte_quaternion_regularization:
The Baumgarte regularization coefficient used to adjust the norm of the
quaternion (only used in integrators not operating on the SO(3) manifold).
Expand All @@ -368,7 +378,7 @@ def system_dynamics(
W_v̇_WB, , aux_dict = system_velocity_dynamics(
model=model,
data=data,
joint_forces=joint_forces,
joint_force_references=joint_force_references,
link_forces=link_forces,
)

Expand Down
34 changes: 18 additions & 16 deletions src/jaxsim/api/ode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ class ODEInput(JaxsimDataclass):
@staticmethod
def build_from_jaxsim_model(
model: js.model.JaxSimModel | None = None,
joint_forces: jtp.VectorLike | None = None,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
) -> ODEInput:
"""
Build an `ODEInput` from a `JaxSimModel`.
Args:
model: The `JaxSimModel` associated with the ODE input.
joint_forces: The vector of joint forces.
link_forces: The matrix of external forces applied to the links.
joint_force_references: The vector of joint force references.
Returns:
The `ODEInput` built from the `JaxSimModel`.
Expand All @@ -60,8 +60,8 @@ def build_from_jaxsim_model(
return ODEInput.build(
physics_model_input=PhysicsModelInput.build_from_jaxsim_model(
model=model,
joint_forces=joint_forces,
link_forces=link_forces,
joint_force_references=joint_force_references,
),
model=model,
)
Expand Down Expand Up @@ -526,16 +526,16 @@ class PhysicsModelInput(JaxsimDataclass):
@staticmethod
def build_from_jaxsim_model(
model: js.model.JaxSimModel | None = None,
joint_forces: jtp.VectorLike | None = None,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
) -> PhysicsModelInput:
"""
Build a `PhysicsModelInput` from a `JaxSimModel`.
Args:
model: The `JaxSimModel` associated with the input.
joint_forces: The vector of joint forces.
link_forces: The matrix of external forces applied to the links.
joint_force_references: The vector of joint force references.
Returns:
A `PhysicsModelInput` instance.
Expand All @@ -546,45 +546,47 @@ def build_from_jaxsim_model(
"""

return PhysicsModelInput.build(
joint_forces=joint_forces,
joint_force_references=joint_force_references,
link_forces=link_forces,
number_of_dofs=model.dofs(),
number_of_links=model.number_of_links(),
)

@staticmethod
def build(
joint_forces: jtp.VectorLike | None = None,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
number_of_dofs: jtp.Int | None = None,
number_of_links: jtp.Int | None = None,
) -> PhysicsModelInput:
"""
Build a `PhysicsModelInput`.
Args:
joint_forces: The vector of joint forces.
link_forces: The matrix of external forces applied to the links.
joint_force_references: The vector of joint force references.
number_of_dofs: The number of degrees of freedom of the model.
number_of_links: The number of links of the model.
Returns:
A `PhysicsModelInput` instance.
"""

joint_forces = (
joint_forces if joint_forces is not None else jnp.zeros(number_of_dofs)
)
joint_force_references = jnp.atleast_1d(
jnp.array(joint_force_references, dtype=float).squeeze()
if joint_force_references is not None
else jnp.zeros(number_of_dofs)
).astype(float)

link_forces = (
link_forces
link_forces = jnp.atleast_2d(
jnp.array(link_forces, dtype=float).squeeze()
if link_forces is not None
else jnp.zeros(shape=(number_of_links, 6))
)
).astype(float)

return PhysicsModelInput(
tau=jnp.array(joint_forces, dtype=float),
f_ext=jnp.array(link_forces, dtype=float),
tau=joint_force_references,
f_ext=link_forces,
)

@staticmethod
Expand Down
12 changes: 6 additions & 6 deletions src/jaxsim/api/references.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def zero(
@staticmethod
def build(
model: js.model.JaxSimModel,
joint_force_references: jtp.Vector | None = None,
link_forces: jtp.Matrix | None = None,
joint_force_references: jtp.VectorLike | None = None,
link_forces: jtp.MatrixLike | None = None,
data: js.data.JaxSimModelData | None = None,
velocity_representation: VelRepr | None = None,
) -> JaxSimModelReferences:
Expand All @@ -78,14 +78,14 @@ def build(

# Create or adjust joint force references.
joint_force_references = jnp.atleast_1d(
joint_force_references.squeeze()
jnp.array(joint_force_references, dtype=float).squeeze()
if joint_force_references is not None
else jnp.zeros(model.dofs())
).astype(float)

# Create or adjust link forces.
f_L = jnp.atleast_2d(
link_forces.squeeze()
jnp.array(link_forces, dtype=float).squeeze()
if link_forces is not None
else jnp.zeros((model.number_of_links(), 6))
).astype(float)
Expand Down Expand Up @@ -299,9 +299,9 @@ def set_joint_force_references(
A new `JaxSimModelReferences` object with the given joint force references.
"""

forces = jnp.array(forces)
forces = jnp.atleast_1d(jnp.array(forces, dtype=float).squeeze())

def replace(forces: jtp.VectorLike) -> JaxSimModelReferences:
def replace(forces: jtp.Vector) -> JaxSimModelReferences:
return self.replace(
validate=True,
input=self.input.replace(
Expand Down
4 changes: 3 additions & 1 deletion src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,9 @@ def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
model=model,
data=data,
link_forces=references.link_forces(model=model, data=data),
joint_forces=references.joint_force_references(model=model),
joint_force_references=references.joint_force_references(
model=model
),
)
)
BW_ν = data.generalized_velocity()
Expand Down
Loading

0 comments on commit 1bbf438

Please sign in to comment.