summaryrefslogtreecommitdiff
path: root/python_src/williamson.py
diff options
context:
space:
mode:
Diffstat (limited to 'python_src/williamson.py')
-rw-r--r--python_src/williamson.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/python_src/williamson.py b/python_src/williamson.py
new file mode 100644
index 0000000..19a2f5f
--- /dev/null
+++ b/python_src/williamson.py
@@ -0,0 +1,70 @@
+def williamson(V, tol=1e-11):
+ r"""Williamson decomposition of positive-definite (real) symmetric matrix.
+
+ See :ref:`williamson`.
+
+ Note that it is assumed that the symplectic form is
+
+ .. math:: \Omega = \begin{bmatrix}0&I\\-I&0\end{bmatrix}
+
+ where :math:`I` is the identity matrix and :math:`0` is the zero matrix.
+
+ See https://math.stackexchange.com/questions/1171842/finding-the-symplectic-matrix-in-williamsons-theorem/2682630#2682630
+
+ Args:
+ V (array[float]): positive definite symmetric (real) matrix
+ tol (float): the tolerance used when checking if the matrix is symmetric: :math:`|V-V^T| \leq` tol
+
+ Returns:
+ tuple[array,array]: ``(Db, S)`` where ``Db`` is a diagonal matrix
+ and ``S`` is a symplectic matrix such that :math:`V = S^T Db S`
+ """
+ (n, m) = V.shape
+
+ if n != m:
+ raise ValueError("The input matrix is not square")
+
+ diffn = np.linalg.norm(V - np.transpose(V))
+
+ if diffn >= tol:
+ raise ValueError("The input matrix is not symmetric")
+
+ if n % 2 != 0:
+ raise ValueError("The input matrix must have an even number of rows/columns")
+
+ n = n // 2
+ omega = sympmat(n)
+ vals = np.linalg.eigvalsh(V)
+
+ for val in vals:
+ if val <= 0:
+ raise ValueError("Input matrix is not positive definite")
+
+ Mm12 = sqrtm(np.linalg.inv(V)).real
+ r1 = Mm12 @ omega @ Mm12
+ s1, K = schur(r1)
+ X = np.array([[0, 1], [1, 0]])
+ I = np.identity(2)
+ seq = []
+
+ # In what follows I construct a permutation matrix p so that the Schur matrix has
+ # only positive elements above the diagonal
+ # Also the Schur matrix uses the x_1,p_1, ..., x_n,p_n ordering thus I use rotmat to
+ # go to the ordering x_1, ..., x_n, p_1, ... , p_n
+
+ for i in range(n):
+ if s1[2 * i, 2 * i + 1] > 0:
+ seq.append(I)
+ else:
+ seq.append(X)
+
+ p = block_diag(*seq)
+ Kt = K @ p
+ s1t = p @ s1 @ p
+ dd = xpxp_to_xxpp(s1t)
+ perm_indices = xpxp_to_xxpp(np.arange(2 * n))
+ Ktt = Kt[:, perm_indices]
+ Db = np.diag([1 / dd[i, i + n] for i in range(n)] + [1 / dd[i, i + n] for i in range(n)])
+ S = Mm12 @ Ktt @ sqrtm(Db)
+ return Db, np.linalg.inv(S).T
+