aboutsummaryrefslogtreecommitdiff
path: root/bloch_messiah.m
diff options
context:
space:
mode:
Diffstat (limited to 'bloch_messiah.m')
-rw-r--r--bloch_messiah.m85
1 files changed, 85 insertions, 0 deletions
diff --git a/bloch_messiah.m b/bloch_messiah.m
new file mode 100644
index 0000000..6724a26
--- /dev/null
+++ b/bloch_messiah.m
@@ -0,0 +1,85 @@
+function [ut1, st1, v1] = bloch_messiah(S, tol, rounding)
+ % Bloch-Messiah decomposition of a symplectic matrix.
+ %
+ % Args:
+ % S (matrix): symplectic matrix
+ % tol (double): tolerance for symplectic check (default: 1e-10)
+ % rounding (int): decimal places for rounding singular values (default: 9)
+ %
+ % Returns:
+ % ut1, st1, v1 (matrices): Decomposition matrices such that S = ut1 * st1 * v1
+
+ if nargin < 2
+ tol = 1e-10;
+ end
+ if nargin < 3
+ rounding = 9;
+ end
+
+ [n, m] = size(S);
+
+ if n ~= m
+ error('The input matrix is not square');
+ end
+
+ if mod(n, 2) ~= 0
+ error('The input matrix must have an even number of rows/columns');
+ end
+
+ n = n / 2;
+ omega = sympmat(n);
+
+ if norm(S' * omega * S - omega) >= tol
+ error('The input matrix is not symplectic');
+ end
+
+ if norm(S' * S - eye(2*n)) >= tol
+ [u, sigma] = polardecomp(S, 'left');
+ [ss, uss] = takagi(sigma, tol, rounding);
+
+ % Apply permutation matrix
+ perm = [1:n, 2*n:-1:n+1];
+ pmat = eye(2*n);
+ pmat = pmat(perm, :);
+
+ ut = uss * pmat;
+
+ % Apply second permutation matrix
+ qomega = ut' * omega * ut;
+ st = pmat * diag(ss) * pmat;
+
+ % Identify degenerate subspaces
+ st_diag = round(diag(st), rounding);
+ [~, ~, ic] = unique(st_diag(1:n));
+ stop_is = cumsum(accumarray(ic, 1));
+ start_is = [0; stop_is(1:end-1)] + 1;
+
+ % Rotation matrices based on SVD
+ u_list = cell(1, length(start_is));
+ v_list = cell(1, length(start_is));
+
+ for i = 1:length(start_is)
+ start_i = start_is(i);
+ stop_i = stop_is(i);
+ x = real(qomega(start_i:stop_i, n+start_i:n+stop_i));
+ [u_svd, ~, v_svd] = svd(x);
+ u_list{i} = u_svd;
+ v_list{i} = v_svd';
+ end
+
+ pmat1 = blkdiag(u_list{:}, v_list{:});
+
+ st1 = pmat1' * pmat * diag(ss) * pmat * pmat1;
+ ut1 = uss * pmat * pmat1;
+ v1 = ut1' * u;
+ else
+ ut1 = S;
+ st1 = eye(2*n);
+ v1 = eye(2*n);
+ end
+
+ ut1 = real(ut1);
+ st1 = real(st1);
+ v1 = real(v1);
+end
+