From f0ee34b3bc3365f9c9c2e2412b4aabfb65e20930 Mon Sep 17 00:00:00 2001 From: mj023 Date: Thu, 19 Feb 2026 18:22:06 +0100 Subject: [PATCH 1/2] Calc r directly --- src/skillmodels/qr.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/skillmodels/qr.py b/src/skillmodels/qr.py index c690eac7..965b9cc3 100644 --- a/src/skillmodels/qr.py +++ b/src/skillmodels/qr.py @@ -19,16 +19,13 @@ def _householder(r: jax.Array, tau: jax.Array): """ m = r.shape[0] n = tau.shape[0] + r = jnp.tril(jnp.fill_diagonal(r, 1, inplace=False)) # Calculate Householder Vector which is saved in the lower triangle of R v1 = jnp.expand_dims(r[:, 0], 1) - v1 = v1.at[0:0].set(0) - v1 = v1.at[0].set(1) h = jnp.eye(m) - tau[0] * (v1 @ jnp.transpose(v1)) # Multiply all Householder Vectors Q = H(1)*H(2)...*H(n) for i in range(1, n): vi = jnp.expand_dims(r[:, i], 1) - vi = vi.at[0:i].set(0) - vi = vi.at[i].set(1) h = h - tau[i] * (h @ vi) @ jnp.transpose(vi) return h[:, :n] From 4cfdc9231480c4ebcbb876fa13c037f90d2247f4 Mon Sep 17 00:00:00 2001 From: mj023 Date: Wed, 25 Feb 2026 12:17:01 +0100 Subject: [PATCH 2/2] Apply Householder directly to R_inv --- src/skillmodels/kalman_filters.py | 4 ++-- src/skillmodels/qr.py | 29 +++++++++++++++++++++++------ 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/skillmodels/kalman_filters.py b/src/skillmodels/kalman_filters.py index f9cfae97..0ab548a2 100644 --- a/src/skillmodels/kalman_filters.py +++ b/src/skillmodels/kalman_filters.py @@ -84,7 +84,7 @@ def kalman_update( _m = _m.at[..., 1:, :1].set(_f_stars) _m = _m.at[..., 1:, 1:].set(upper_chols) - _r = array_qr_jax(_m)[1] + _r = array_qr_jax(_m) _new_upper_chols = _r[..., 1:, 1:] _root_sigmas = _r[..., 0, 0] @@ -223,7 +223,7 @@ def kalman_predict( qr_points = jnp.zeros((n_obs, n_mixtures, n_sigma + n_fac, n_fac)) qr_points = qr_points.at[:, :, 0:n_sigma].set(devs * qr_weights) qr_points = qr_points.at[:, :, n_sigma:].set(jnp.diag(shock_sds)) - predicted_covs = 
array_qr_jax(qr_points)[1][:, :, :n_fac] + predicted_covs = array_qr_jax(qr_points)[:, :, :n_fac] return predicted_states, predicted_covs diff --git a/src/skillmodels/qr.py b/src/skillmodels/qr.py index 965b9cc3..4a5dfedb 100644 --- a/src/skillmodels/qr.py +++ b/src/skillmodels/qr.py @@ -7,8 +7,7 @@ def qr_gpu(a: jax.Array): """Custom implementation of the QR Decomposition.""" r, tau = jnp.linalg.qr(a, mode="raw") - q = _householder(r.mT, tau) - return q, jnp.triu(r.mT[: tau.shape[0]]) + return jnp.triu(r.mT[: tau.shape[0]]) def _householder(r: jax.Array, tau: jax.Array): @@ -30,6 +29,24 @@ def _householder(r: jax.Array, tau: jax.Array): return h[:, :n] +def _apply_householder_t(r: jax.Array, tau: jax.Array, a: jax.Array): + """Apply the Householder reflectors of a raw QR factorization to a matrix. + + Uses the outputs of jnp.linalg.qr with mode = "raw" to apply the reflectors stored in r directly to ``a`` instead of materializing Q. This is needed + because the JAX implementation is extremely slow for a batch of small matrices. + """ + n = tau.shape[0] + r = jnp.tril(jnp.fill_diagonal(r, 1, inplace=False)) + # The Householder vectors are stored in the columns of the lower triangle of r + v1 = jnp.expand_dims(r[:, n - 1], 1) + h = a - tau[n - 1] * a @ v1 @ jnp.transpose(v1) + # Apply the remaining Householder reflectors in reverse order + for i in range(n - 2, -1, -1): + vi = jnp.expand_dims(r[:, i], 1) + h = h - tau[i] * (h @ vi) @ jnp.transpose(vi) + return h[:, :n] + + def _t(x: jax.Array) -> jax.Array: """Transpose batched Matrix.""" return jax.lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2)) @@ -53,15 +70,15 @@ def qr_jvp_rule(primals, tangents): # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
(x,) = primals (dx,) = tangents - q, r = qr_gpu(x) + r_raw, tau = jnp.linalg.qr(x, mode="raw") + r = jnp.triu(r_raw.mT[: tau.shape[0]]) dx_rinv = jax.lax.linalg.triangular_solve(r, dx) # Right side solve by default - qt_dx_rinv = _h(q) @ dx_rinv + qt_dx_rinv = _apply_householder_t(r_raw.mT, tau, dx_rinv) qt_dx_rinv_lower = _tril(qt_dx_rinv, -1) do = qt_dx_rinv_lower - _h(qt_dx_rinv_lower) # This is skew-symmetric # The following correction is necessary for complex inputs n = x.shape[-1] i = jax.lax.expand_dims(jnp.eye(n, n), range(qt_dx_rinv.ndim - 2)) do = do + i * (qt_dx_rinv - qt_dx_rinv.real.astype(qt_dx_rinv.dtype)) - dq = q @ (do - qt_dx_rinv) + dx_rinv dr = (qt_dx_rinv - do) @ r - return (q, r), (dq, dr) + return (r), (dr)