Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 78 additions & 2 deletions src/bindings.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,10 @@ public:
PyQOCOSolution &get_solution();

QOCOInt update_settings(const QOCOSettings &);
// QOCOInt update_vector_data(py::object, py::object, py::object);
// QOCOInt update_matrix_data(py::object, py::object, py::object);
//QOCOInt update_vector_data(py::object, py::object, py::object);
//QOCOInt update_matrix_data(py::object, py::object, py::object);
QOCOInt update_vector_data(py::object cnew, py::object bnew, py::object hnew);
QOCOInt update_matrix_data(py::object Pxnew, py::object Axnew, py::object Gxnew);

QOCOInt solve();

Expand Down Expand Up @@ -241,6 +243,78 @@ QOCOInt PyQOCOSolver::update_settings(const QOCOSettings &new_settings)
return qoco_update_settings(this->_solver, &new_settings);
}

QOCOInt PyQOCOSolver::update_vector_data(py::object cnew, py::object bnew, py::object hnew)
{
QOCOFloat *cnew_ptr = nullptr;
QOCOFloat *bnew_ptr = nullptr;
QOCOFloat *hnew_ptr = nullptr;

if (cnew != py::none())
{
auto cnew_arr = cnew.cast<py::array_t<QOCOFloat>>();
auto buf = cnew_arr.request();
if (buf.shape[0] != this->n)
throw std::runtime_error("cnew size must be n = " + std::to_string(this->n));
cnew_ptr = (QOCOFloat *)buf.ptr;
}

if (bnew != py::none())
{
auto bnew_arr = bnew.cast<py::array_t<QOCOFloat>>();
auto buf = bnew_arr.request();
if (buf.shape[0] != this->p)
throw std::runtime_error("bnew size must be p = " + std::to_string(this->p));
bnew_ptr = (QOCOFloat *)buf.ptr;
}

if (hnew != py::none())
{
auto hnew_arr = hnew.cast<py::array_t<QOCOFloat>>();
auto buf = hnew_arr.request();
if (buf.shape[0] != this->m)
throw std::runtime_error("hnew size must be m = " + std::to_string(this->m));
hnew_ptr = (QOCOFloat *)buf.ptr;
}

return qoco_update_vector_data(this->_solver, cnew_ptr, bnew_ptr, hnew_ptr);
}

QOCOInt PyQOCOSolver::update_matrix_data(py::object Pxnew, py::object Axnew, py::object Gxnew)
{
QOCOFloat *Pxnew_ptr = nullptr;
QOCOFloat *Axnew_ptr = nullptr;
QOCOFloat *Gxnew_ptr = nullptr;

if (Pxnew != py::none())
{
auto Pxnew_arr = Pxnew.cast<py::array_t<QOCOFloat>>();
auto buf = Pxnew_arr.request();
if (buf.ndim != 1)
throw std::runtime_error("Pxnew must be 1-D array");
Pxnew_ptr = (QOCOFloat *)buf.ptr;
}

if (Axnew != py::none())
{
auto Axnew_arr = Axnew.cast<py::array_t<QOCOFloat>>();
auto buf = Axnew_arr.request();
if (buf.ndim != 1)
throw std::runtime_error("Axnew must be 1-D array");
Axnew_ptr = (QOCOFloat *)buf.ptr;
}

if (Gxnew != py::none())
{
auto Gxnew_arr = Gxnew.cast<py::array_t<QOCOFloat>>();
auto buf = Gxnew_arr.request();
if (buf.ndim != 1)
throw std::runtime_error("Gxnew must be 1-D array");
Gxnew_ptr = (QOCOFloat *)buf.ptr;
}

return qoco_update_matrix_data(this->_solver, Pxnew_ptr, Axnew_ptr, Gxnew_ptr);
}

PYBIND11_MODULE(@QOCO_EXT_MODULE_NAME@, m)
{
// Enums.
Expand Down Expand Up @@ -308,6 +382,8 @@ PYBIND11_MODULE(@QOCO_EXT_MODULE_NAME@, m)
.def(py::init<QOCOInt, QOCOInt, QOCOInt, const CSC &, const py::array_t<QOCOFloat>, const CSC &, const py::array_t<QOCOFloat>, const CSC &, const py::array_t<QOCOFloat>, QOCOInt, QOCOInt, const py::array_t<QOCOInt>, QOCOSettings *>(), "n"_a, "m"_a, "p"_a, "P"_a, "c"_a.noconvert(), "A"_a, "b"_a.noconvert(), "G"_a, "h"_a.noconvert(), "l"_a, "nsoc"_a, "q"_a.noconvert(), "settings"_a)
.def_property_readonly("solution", &PyQOCOSolver::get_solution, py::return_value_policy::reference)
.def("update_settings", &PyQOCOSolver::update_settings)
.def("update_vector_data", &PyQOCOSolver::update_vector_data, "cnew"_a=py::none(), "bnew"_a=py::none(), "hnew"_a=py::none())
.def("update_matrix_data", &PyQOCOSolver::update_matrix_data, "Pxnew"_a=py::none(), "Axnew"_a=py::none(), "Gxnew"_a=py::none())
.def("solve", &PyQOCOSolver::solve)
.def("get_settings", &PyQOCOSolver::get_settings, py::return_value_policy::reference);
}
95 changes: 95 additions & 0 deletions src/qoco/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,101 @@ def update_settings(self, **kwargs):
if settings_changed and self._solver is not None:
self._solver.update_settings(self.settings)

def update_vector_data(self, c=None, b=None, h=None):
"""
Update data vectors.

Parameters
----------
c : np.ndarray, optional
New c vector of size n. If None, c is not updated. Default is None.
b : np.ndarray, optional
New b vector of size p. If None, b is not updated. Default is None.
h : np.ndarray, optional
New h vector of size m. If None, h is not updated. Default is None.

Returns
-------
int
Status code from the solver
"""
cnew_ptr = None
bnew_ptr = None
hnew_ptr = None

if c is not None:
if not isinstance(c, np.ndarray):
c = np.array(c)
c = c.astype(np.float64)
if c.shape[0] != self.n:
raise ValueError(f"c size must be n = {self.n}")
cnew_ptr = c

if b is not None:
if not isinstance(b, np.ndarray):
b = np.array(b)
b = b.astype(np.float64)
if b.shape[0] != self.p:
raise ValueError(f"b size must be p = {self.p}")
bnew_ptr = b

if h is not None:
if not isinstance(h, np.ndarray):
h = np.array(h)
h = h.astype(np.float64)
if h.shape[0] != self.m:
raise ValueError(f"h size must be m = {self.m}")
hnew_ptr = h

return self._solver.update_vector_data(cnew_ptr, bnew_ptr, hnew_ptr)

def update_matrix_data(self, P=None, A=None, G=None):
"""
Update sparse matrix data.

The new matrices must have the same sparsity structure as the original ones.

Parameters
----------
P : np.ndarray, optional
New data for P matrix (only the nonzero values). If None, P is not updated.
Default is None.
A : np.ndarray, optional
New data for A matrix (only the nonzero values). If None, A is not updated.
Default is None.
G : np.ndarray, optional
New data for G matrix (only the nonzero values). If None, G is not updated.
Default is None.

Returns
-------
int
Status code from the solver
"""
Pxnew_ptr = None
Axnew_ptr = None
Gxnew_ptr = None

if P is not None:
if not isinstance(P, np.ndarray):
P = np.array(P)
P = P.astype(np.float64)
Pxnew_ptr = P

if A is not None:
if not isinstance(A, np.ndarray):
A = np.array(A)
A = A.astype(np.float64)
Axnew_ptr = A

if G is not None:
if not isinstance(G, np.ndarray):
G = np.array(G)
G = G.astype(np.float64)
Gxnew_ptr = G

return self._solver.update_matrix_data(Pxnew_ptr, Axnew_ptr, Gxnew_ptr)

def setup(self, n, m, p, P, c, A, b, G, h, l, nsoc, q, **settings):
self.m = m
self.n = n
Expand Down