aboutsummaryrefslogtreecommitdiff
path: root/bloch_messiah.m
blob: 4050102e417834a98ccadd4582cc7a57c6f8201f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
function [ut1, st1, v1] = bloch_messiah(S, tol, rounding)
    % Bloch-Messiah decomposition of a symplectic matrix.
    %
    % Factors a 2n-by-2n symplectic matrix S as S = ut1 * st1 * v1. The
    % factorization is built from a left polar decomposition of S and a Takagi
    % factorization of its symmetric factor. Within each group of degenerate
    % squeezing values, the Takagi vectors are further rotated using an SVD.
    %
    % Args:
    %   S (matrix): symplectic matrix (square, with an even number of rows)
    %   tol (double): tolerance for the symplectic/orthogonality checks, also
    %       passed to takagi (default: 1e-10)
    %   rounding (integer): number of decimal digits to which the squeezing
    %       values are rounded when deciding whether neighbouring values are
    %       degenerate (default: 9)
    %
    % Returns:
    %   ut1, st1, v1 (matrices): real decomposition matrices such that
    %       S = ut1 * st1 * v1. If S is already orthogonal (within tol),
    %       ut1 = S and st1 = v1 = identity.
    %
    % Raises:
    %   error if S is not square, has an odd size, or is not symplectic.

    if nargin < 2 || isempty(tol)
        tol = 1e-10;
    end

    if nargin < 3 || isempty(rounding)
        rounding = 9;
    end

    [n, m] = size(S);

    if n ~= m
        error('The input matrix is not square');
    end

    if mod(n, 2) ~= 0
        error('The input matrix must have an even number of rows/columns');
    end

    n = n / 2;
    omega = sympmat(n);

    if norm(S' * omega * S - omega) >= tol
        error('The input matrix is not symplectic');
    end

    if norm(S' * S - eye(2*n)) >= tol
        % Left polar decomposition of S, then Takagi factorization of the
        % symmetric factor sigma.
        [u, sigma] = polardecomp(S, 'left');
        [ss, uss] = takagi(sigma, tol);

        % Keep the first n Takagi values in place and reverse the last n.
        % NOTE(review): this gives the order s_1..s_n, 1/s_1..1/s_n only if
        % takagi returns its values in decreasing order -- confirm in takagi.m.
        perm = [1:n, 2*n:-1:n+1];
        pmat = eye(2*n);
        pmat = pmat(perm, :);

        ut = uss * pmat;

        % Symplectic form expressed in the permuted Takagi basis. A plain
        % transpose (.') is used because the Takagi factorization is defined
        % with a transpose, not a conjugate transpose.
        qomega = ut.' * omega * ut;
        st = pmat * diag(ss) * pmat;

        % Identify degenerate subspaces: runs of consecutive entries among
        % the first n diagonal values that agree to `rounding` decimal
        % digits. Runs are detected in place (unique() would reorder the
        % values), and rounding is needed because numerically degenerate
        % values are rarely bit-for-bit equal.
        st_diag = real(diag(st));
        keys = round(st_diag(1:n) * 10^rounding);
        breaks = find(diff(keys) ~= 0);
        stop_is = [breaks; n];
        start_is = [1; breaks + 1];

        % Rotation matrices based on the SVD of each degenerate off-diagonal
        % block of qomega: with x = U*S*V' (MATLAB convention), U'*x*V = S.
        nblocks = numel(start_is);
        u_list = cell(1, nblocks);
        v_list = cell(1, nblocks);

        for i = 1:nblocks
            idx = start_is(i):stop_is(i);
            x = real(qomega(idx, n + idx));
            [u_svd, ~, v_svd] = svd(x);
            u_list{i} = u_svd;
            % MATLAB's svd already returns V (not V'), so it is used as is.
            v_list{i} = v_svd;
        end

        pmat1 = blkdiag(u_list{:}, v_list{:});

        st1 = pmat1.' * st * pmat1;
        ut1 = ut * pmat1;
        % S = sigma * u = ut1 * st1 * ut1.' * u, hence v1 = ut1.' * u.
        v1 = ut1.' * u;
    else
        % S is already orthogonal (within tol): no squeezing component.
        ut1 = S;
        st1 = eye(2*n);
        v1 = eye(2*n);
    end

    % Discard any imaginary parts (e.g. from a complex-typed takagi output).
    ut1 = real(ut1);
    st1 = real(st1);
    v1 = real(v1);
end