Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance Forward Dynamics computation by adding motor dynamics #62

Merged
merged 43 commits into from
Dec 6, 2023
Merged
Changes from 1 commit
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
df52281
Add motor methods in high_level.joint
flferretti Jul 20, 2023
40bb04d
Set motor parameters in high_level.model
flferretti Jul 20, 2023
e6c7410
Add motor parameters in parsers.descriptions.joint
flferretti Jul 20, 2023
64cadbc
Add motor parameters in physics.model.physics_model
flferretti Jul 20, 2023
81ae493
Add test script for aba
flferretti Jul 20, 2023
ac75b67
Update documentation
flferretti Aug 1, 2023
9a622a6
Extract motor inertias
flferretti Aug 1, 2023
ac5ab0f
Add motor inertias in first backward pass
flferretti Aug 1, 2023
430c641
Add motor viscous friction in forward pass
flferretti Aug 1, 2023
25b9dd4
Add motor parameters initialization if the model has no motors
flferretti Aug 1, 2023
56487df
Fix minor typos and delete test_aba
flferretti Aug 1, 2023
dedb032
Compute FD with CRBA considering motor parameters
flferretti Aug 1, 2023
85381c6
Approximate ABA solution considering motor inertias as linear compone…
flferretti Aug 3, 2023
c76a0bf
Add test for forward dynamics with motor parameters
flferretti Aug 3, 2023
22ebf2d
Add dirty fix
flferretti Aug 11, 2023
5768727
Disable `jax.jit` during motors test
flferretti Sep 5, 2023
323b5f2
Add `has_motors` attribute to model
flferretti Sep 5, 2023
4288f3a
Add motor parameters in FD computation with RNEA
flferretti Sep 5, 2023
97c6e0f
Fix formatting
flferretti Sep 5, 2023
3ec90ef
Update ABA and fix formatting
flferretti Sep 5, 2023
203b09b
Update motor dynamics test with latest rebase
flferretti Sep 7, 2023
b42cc60
Rename variables to improve readability
flferretti Sep 12, 2023
dcb13d2
Fix ABA and RNEA with motor dynamics
flferretti Sep 12, 2023
16ff2b3
Refactor code and add oop decorators for JIT compilation
flferretti Sep 12, 2023
4d55ec5
Refactor motor dynamics test
flferretti Sep 12, 2023
58e2ca8
Correct typos in ABA and comment oop decorators
flferretti Sep 20, 2023
dd8e138
Solve deprecation warning
flferretti Oct 20, 2023
5fadf75
Update setters for motor parameters
flferretti Oct 20, 2023
da40016
Remove unnecessary assertion in crb
flferretti Oct 25, 2023
65f4046
Revert `rnea` modification
flferretti Oct 25, 2023
99a6e8e
Refactor and add viscous frictions in ABA
flferretti Oct 25, 2023
a171f95
Change has_motors logic
flferretti Oct 26, 2023
0bcc5dc
Avoid implicit jax.Array to bool conversion
flferretti Oct 26, 2023
4ff83ab
Restore RNEA with motors
flferretti Oct 26, 2023
6c1aab0
Update tests for motors
flferretti Oct 26, 2023
f03e4c8
Remove motor dynamics test
flferretti Nov 8, 2023
2bee17d
Fix motors for floating base
flferretti Nov 8, 2023
af1d537
Fix ABA for viscous friction
flferretti Nov 8, 2023
eb23ad8
Update ABA and RNEA
flferretti Nov 13, 2023
bb64a86
Fix CRB
flferretti Nov 22, 2023
7f4ced4
Fix CRB FD for fixed-based models
flferretti Dec 6, 2023
651a6b1
Enforce dtypes in motor parameters
flferretti Dec 6, 2023
f82e351
Put ABA and RNEA with motors in separate files
flferretti Dec 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix ABA and RNEA with motor dynamics
flferretti committed Dec 6, 2023
commit dcb13d2d8a85b0b5197fd8c0862428eb93de1bcd
68 changes: 25 additions & 43 deletions src/jaxsim/physics/algos/aba.py
Original file line number Diff line number Diff line change
@@ -175,51 +175,33 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]:
U, m_U, d, u, m_u, MA, pA = carry

# Compute intermediate results
U_i = MA[i] @ S[i]
U = U.at[i].set(U_i)

d_i = S[i].T @ U[i] + S[i].T / Γ[i] @ m_U[i]
d = d.at[i].set(d_i.squeeze())

# Compute joint velocities
vJ = S[i] * qd[ii] if qd.size != 0 else S[i] * 0

u_i = tau[ii] - S[i].T @ pA[i] if tau.size != 0 else -S[i].T @ pA[i]
u = u.at[i].set(u_i.squeeze())

# Add motor dynamics to the articulated-body inertia and bias forces
def add_motors(
carry: Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.VectorJax]
) -> Tuple[jtp.VectorJax, jtp.MatrixJax, jtp.MatrixJax, jtp.VectorJax]:
MA, pA, m_U, m_u = carry

m_U_i = I_m[i] * S[i] / Γ[i]
m_U = m_U.at[i].set(m_U_i)
m_u_i = (
tau[ii] * Γ[ii] - S[i].T @ pA[i] - K̅ᵥ[ii].T * qd[ii] / Γ[i]
if tau.size != 0
else -S[i].T @ pA[i] - K̅ᵥ[ii].T * qd[ii] / Γ[i]
)
m_u = m_u.at[i].set(m_u_i.squeeze())
# print(f"Adding {- m_U[i] / d[i] @ m_U[i].T} to MA")
MA_i = MA[i] - m_U[i] / d[i] @ m_U[i].T
MA = MA.at[i].set(MA_i)

pA_i = pA[i] + pR[i] + m_U[i] * m_u[i] / d[i] + I_m[i] * m_c[i]
pA = pA.at[i].set(pA_i)

return MA, pA, m_U, m_u

MA, pA, m_U, m_u = jax.lax.cond(
pred=model.has_motors,
true_fun=add_motors,
false_fun=lambda carry: carry,
operand=(MA, pA, m_U, m_u),
m_u_i = (
tau[ii] / Γ[i] - m_S[i].T @ pR[i] if tau.size != 0 else -m_S[i].T @ pR[i]
)
m_u = m_u.at[i].set(m_u_i.squeeze())

U_i = MA[i] @ S[i]
U = U.at[i].set(U_i)

m_U_i = IM[i] @ m_S[i]
m_U = m_U.at[i].set(m_U_i)

D_i = S[i].T @ MA[i] @ S[i] + m_S[i].T @ IM[i] @ m_S[i]
D = D.at[i].set(D_i.squeeze())

# Compute the articulated-body inertia and bias forces of this link
Ma = MA[i] - U[i] / d[i] @ U[i].T
pa = pA[i] + Ma @ c[i] + U[i] * u[i] / d[i]
Ma = MA[i] + IM[i] - U[i] / D[i] @ U[i].T - m_U[i] / D[i] @ m_U[i].T
pa = (
pA[i]
+ pR[i]
+ Ma[i] @ c[i]
+ IM[i] @ m_c[i]
+ U[i] / D[i] * u[i]
+ m_U[i] / D[i] * m_u[i]
)

# Propagate them to the parent, handling the base link
def propagate(
@@ -230,7 +212,7 @@ def propagate(
MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]
MA = MA.at[λ[i]].set(MA_λi)

pA_λi = pA[λ[i]] + i_X_λi[i].T @ (pa + pR[i])
pA_λi = pA[λ[i]] + i_X_λi[i].T @ pa
pA = pA.at[λ[i]].set(pA_λi)

return MA, pA
@@ -268,13 +250,13 @@ def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> Tuple[Pass3Carry, None]:
a, qdd = carry

# Propagate link accelerations
a_i = i_X_λi[i] @ a[λ[i]] + c[i]
a_i = i_X_λi[i] @ a[λ[i]]

# Compute joint accelerations
qdd_ii = (u[i] - U[i].T @ a_i) / d[i] + (m_u[i] - m_U[i].T @ a_i * Γ[i]) / d[i]
qdd_ii = (u[i] + m_u[i] - (U[i].T + m_U[i].T) @ a_i) / D[i]
qdd = qdd.at[ii].set(qdd_ii.squeeze()) if qdd.size != 0 else qdd

a_i = a_i + S[i] * qdd[ii] if qdd.size != 0 else a_i
a_i = a_i + S[i] * qdd[ii] + c[i] if qdd.size != 0 else a_i
a = a.at[i].set(a_i)

return (a, qdd), None
7 changes: 4 additions & 3 deletions src/jaxsim/physics/algos/rnea.py
Original file line number Diff line number Diff line change
@@ -121,20 +121,21 @@ def forward_pass(
i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m = carry

vJ = S[i] * qd[ii]
vJ_m = S_m[i] * qd[ii]

i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

v_i = i_X_λi[i] @ v[λ[i]] + vJ
v = v.at[i].set(v_i)

vJ_m = S[i] * qd[ii] / Γ[i]
v_i_m = v_m[λ[i]] + vJ_m
v_i_m = i_X_λi[i] @ v_m[λ[i]] + vJ_m
v_m = v_m.at[i].set(v_i_m)

a_i = i_X_λi[i] @ a[λ[i]] + S[i] * qdd[ii] + Cross.vx(v[i]) @ vJ
a = a.at[i].set(a_i)

a_i_m = S[i] * qdd[ii] / Γ[i] + Cross.vx(v_m[i]) @ vJ_m
a_i_m = i_X_λi[i] @ a_m[λ[i]] + S_m[i] * qdd[ii] + Cross.vx(v_m[i]) @ vJ_m
a_m = a_m.at[i].set(a_i_m)

i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]