Fix type hints in chain_mass solution sensitivity example by dirkpr · Pull Request #1452 · acados/acados · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content
/ acados Public

Fix type hints in chain_mass solution sensitivity example #1452

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 3, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions examples/acados_python/chain_mass/solution_sensitivity_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import casadi as ca
from casadi import SX, norm_2, vertcat
from casadi.tools import struct_symSX, entry
from casadi.tools.structure3 import DMStruct
from casadi.tools.structure3 import DMStruct, ssymStruct
import matplotlib.pyplot as plt
from acados_template import AcadosModel, AcadosOcp, AcadosOcpSolver
from utils import get_chain_params
Expand All @@ -47,7 +47,7 @@
import time


def export_discrete_erk4_integrator_step(f_expl: SX, x: SX, u: SX, p: struct_symSX, h: float, n_stages: int = 2) -> ca.SX:
def export_discrete_erk4_integrator_step(f_expl: SX, x: SX, u: SX, p: ssymStruct, h: float, n_stages: int = 2) -> ca.SX:
"""Define ERK4 integrator for continuous dynamics."""
dt = h / n_stages
ode = ca.Function("f", [x, u, p], [f_expl])
Expand All @@ -62,7 +62,7 @@ def export_discrete_erk4_integrator_step(f_expl: SX, x: SX, u: SX, p: struct_sym
return xnext


def define_param_struct_symSX(n_mass: int, disturbance: bool = True) -> struct_symSX:
def define_param_ssymStruct(n_mass: int, disturbance: bool = True) -> ssymStruct:
"""Define parameter struct."""
n_link = n_mass - 1

Expand Down Expand Up @@ -111,7 +111,7 @@ def export_chain_mass_model(n_mass: int, Ts: float = 0.2, disturbance: bool = Fa
xdot = SX.sym("xdot", nx, 1)

f = SX.zeros(3 * M, 1) # force on intermediate masses
p = define_param_struct_symSX(n_mass=n_mass, disturbance=disturbance)
p = define_param_ssymStruct(n_mass=n_mass, disturbance=disturbance)

# Gravity force
for i in range(M):
Expand Down Expand Up @@ -189,10 +189,9 @@ def export_chain_mass_model(n_mass: int, Ts: float = 0.2, disturbance: bool = Fa


def compute_parametric_steady_state(
model: AcadosModel, p: struct_symSX, xPosFirstMass: np.ndarray, xEndRef: np.ndarray
model: AcadosModel, p: DMStruct, xPosFirstMass: np.ndarray, xEndRef: np.ndarray
) -> np.ndarray:
"""Compute steady state for chain mass model."""
# TODO reuse/adapt the compute_steady_state function in utils.py

p_ = p(0)
p_["m"] = p["m"]
Expand Down Expand Up @@ -301,13 +300,13 @@ def export_parametric_ocp(
x_e = ocp.model.x - x_ss
u_e = ocp.model.u - np.zeros((nu, 1))

idx = find_idx_for_labels(define_param_struct_symSX(chain_params_["n_mass"], disturbance=True).cat, "Q")
idx = find_idx_for_labels(define_param_ssymStruct(chain_params_["n_mass"], disturbance=True).cat, "Q")
Q_sym = ca.reshape(ocp.model.p_global[idx], (nx, nx))
q_diag = np.ones((nx, 1))
q_diag[3 * M : 3 * M + 3] = M + 1
p["Q"] = 2 * np.diagflat(q_diag)

idx = find_idx_for_labels(define_param_struct_symSX(chain_params_["n_mass"], disturbance=True).cat, "R")
idx = find_idx_for_labels(define_param_ssymStruct(chain_params_["n_mass"], disturbance=True).cat, "R")
R_sym = ca.reshape(ocp.model.p_global[idx], (nu, nu))
p["R"] = 2 * np.diagflat(1e-2 * np.ones((nu, 1)))

Expand Down Expand Up @@ -339,11 +338,6 @@ def export_parametric_ocp(
ocp.solver_options.qp_solver_ric_alg = qp_solver_ric_alg
ocp.solver_options.qp_solver_cond_N = ocp.solver_options.N_horizon
ocp.solver_options.with_solution_sens_wrt_params = True
# # Old settings needed, when calling solve() instead of setup_qp_matrices_and_factorize()
# ocp.solver_options.globalization_fixed_step_length = 0.0
# ocp.solver_options.nlp_solver_max_iter = 1
# ocp.solver_options.qp_solver_iter_max = 200
# ocp.solver_options.tol = 1e-10
else:
ocp.solver_options.nlp_solver_max_iter = nlp_iter
ocp.solver_options.qp_solver_cond_N = ocp.solver_options.N_horizon
Expand Down Expand Up @@ -399,7 +393,7 @@ def main_parametric(qp_solver_ric_alg: int = 0, chain_params_: dict = get_chain_
# p_label = "D_2_0"
p_label = f"C_{M}_0"

p_idx = find_idx_for_labels(define_param_struct_symSX(chain_params_["n_mass"], disturbance=True).cat, p_label)[0]
p_idx = find_idx_for_labels(define_param_ssymStruct(chain_params_["n_mass"], disturbance=True).cat, p_label)[0]

p_var = np.linspace(0.5 * parameter_values.cat[p_idx], 1.5 * parameter_values.cat[p_idx], np_test).flatten()

Expand Down
Loading
0