# Source code for mvpure_py.localizer.localizer_utils

""" Functions used during sources localization. """

# Authors: Julia Jurkowska, Tomasz Piotrowski

import os
from typing import List, Optional, Tuple

import numpy as np
import scipy.linalg as sla
from joblib import Parallel, delayed
from scipy.linalg import cho_factor, cho_solve, pinv
from tqdm import tqdm

from ..utils import algebra
from ..viz import plot_RN_eigenvalues

# Number of threads used by BLAS/LAPACK - ADJUSTABLE
# NOTE(review): numpy/scipy are imported above this line, so the BLAS backend
# may have already initialized its thread pool before this assignment —
# confirm it actually takes effect, or set the variable before any import.
os.environ["OMP_NUM_THREADS"] = "10"


# ============================================================
# ---------------- Helper: build batch blocks ----------------
# ============================================================

def _build_blocks_batch(H_sel: np.ndarray,
                        M_sel: Optional[np.ndarray],
                        A_M_batch: np.ndarray,
                        H_batch: np.ndarray) -> np.ndarray:
    """
    Construct block matrices for a batch of candidate sources.

    Parameters
    ----------
    H_sel : np.ndarray
        Already selected columns of H.
    M_sel : np.ndarray or None
        Current accumulated block matrix (None if empty).
    A_M_batch : np.ndarray
        Batch of transformed candidate sources (e.g., R⁻¹ H).
    H_batch : np.ndarray
        Corresponding candidate columns from H.

    Returns
    -------
    blocks : np.ndarray
        Batch of block matrices, shape (B, m, m).
    """
    k = 0 if M_sel is None else M_sel.shape[0]   # size of current block
    B = A_M_batch.shape[1]                       # number of candidates in the batch
    m = k + 1                                    # new block size
    dtype = A_M_batch.dtype

    if k == 0:
        # First iteration: build scalar blocks (1x1 matrices)
        s_batch = np.sum(A_M_batch * H_batch, axis=0)
        return s_batch.reshape(B, 1, 1).astype(dtype, copy=False)

    # Cross-terms with already selected sources
    a_batch = H_sel[:, :k].T @ A_M_batch
    s_batch = np.sum(A_M_batch * H_batch, axis=0)

    # Assemble full block matrices for the batch
    blocks = np.empty((B, m, m), dtype=dtype)
    blocks[:, :k, :k] = M_sel[None, :, :]             # top-left block (previous selection)
    blocks[:, :k, k:k+1] = a_batch.T.reshape(B, k, 1) # last column
    blocks[:, k:k+1, :k] = a_batch.T.reshape(B, 1, k) # last row
    blocks[:, k, k] = s_batch                         # bottom-right element

    return blocks


# ============================================================
# ---------------- Activity index: simplified ----------------
# ============================================================

def _simple_mai_batch(blocks_G: np.ndarray,
                      blocks_S: np.ndarray,
                      current_iter: int) -> np.ndarray:
    """
    Compute the MAI activity index for a batch (simplified form) in parallel.

    Definition:
        MAI = trace(G S⁻¹) - current_iter
    """
    B, m, _ = blocks_S.shape

    # Handle 1x1 blocks quickly
    if m == 1:
        return blocks_G[:, 0, 0] / blocks_S[:, 0, 0] - current_iter

    I = np.eye(m)

    def _process_block(i):
        # Inverse of S
        try:
            invS = cho_solve(cho_factor(blocks_S[i], lower=False, check_finite=False),
                             I, check_finite=False)
        except np.linalg.LinAlgError:
            invS = pinv(blocks_S[i])

        # Activity index for this block
        return np.trace(blocks_G[i] @ invS) - current_iter

    # Parallel execution across all blocks
    results = Parallel(n_jobs=-1)(delayed(_process_block)(i) for i in range(B))
    return np.array(results, dtype=float)


def _simple_mpz_batch(blocks_S: np.ndarray,
                      blocks_T: np.ndarray,
                      current_iter: int) -> np.ndarray:
    """
    Compute the MPZ activity index for a batch (simplified form) in parallel.

    Definition:
        MPZ = trace(S T⁻¹) - current_iter
    """
    B, m, _ = blocks_S.shape

    # Handle 1x1 blocks quickly
    if m == 1:
        return blocks_S[:, 0, 0] / blocks_T[:, 0, 0] - current_iter

    I = np.eye(m)

    def _process_block(i):
        # Inverse of T
        try:
            invT = cho_solve(
                cho_factor(blocks_T[i], lower=False, check_finite=False),
                I, check_finite=False
            )
        except np.linalg.LinAlgError:
            invT = np.linalg.pinv(blocks_T[i])

        return np.trace(blocks_S[i] @ invT) - current_iter

    # Parallel execution across blocks
    results = Parallel(n_jobs=-1)(delayed(_process_block)(i) for i in range(B))
    return np.array(results, dtype=float)


# ============================================================
# ---------------- Activity index: full versions --------------
# ============================================================

def _mai_mvp_batch(blocks_G: np.ndarray,
                   blocks_S: np.ndarray,
                   r: int,
                   current_iter: int) -> np.ndarray:
    """
    Evaluate the MAI-MVP activity index for a batch of blocks in parallel.

    Eigenvalue-based definition:
        MAI-MVP = (sum of the r largest eigenvalues of G S⁻¹) - r

    Iterations up to r reduce to the simplified MAI definition.
    """
    n_blocks, block_dim, _ = blocks_S.shape

    # Early iterations (<= r) fall back to the simplified index.
    if current_iter <= r:
        return _simple_mai_batch(blocks_G, blocks_S, current_iter)

    identity = np.eye(block_dim)

    def _mai_mvp_single(idx):
        if block_dim == 1:
            return blocks_G[idx, 0, 0] / blocks_S[idx, 0, 0] - r

        S_i = blocks_S[idx]
        # Cholesky-based inverse of S with pseudoinverse fallback.
        try:
            S_inv = cho_solve(cho_factor(S_i, lower=False, check_finite=False),
                              identity, check_finite=False)
        except np.linalg.LinAlgError:
            S_inv = pinv(S_i)

        # Sum of the r largest eigenvalues of G S⁻¹, sorted descending.
        eigvals, _ = np.linalg.eig(blocks_G[idx] @ S_inv)
        top = np.sort(eigvals)[::-1].real
        return np.sum(top[:r]) - r

    # Parallel execution across blocks (joblib preserves input order)
    values = Parallel(n_jobs=-1)(
        delayed(_mai_mvp_single)(idx) for idx in range(n_blocks)
    )
    return np.asarray(values, dtype=float)


def _mpz_mvp_batch(blocks_S: np.ndarray,
                   blocks_T: np.ndarray,
                   blocks_G: np.ndarray,
                   r: int,
                   current_iter: int) -> np.ndarray:
    """
    Evaluate the MPZ-MVP activity index for a batch of blocks.

    Exact CPU version with a block projection:
        MPZ-MVP = trace(S T⁻¹ P) - r
    where P is the oblique projection onto the top-r eigenspace of S Q,
    with Q = S⁻¹ - G⁻¹. Iterations up to r reduce to the simplified MPZ
    definition.
    """
    n_blocks, block_dim, _ = blocks_S.shape

    if current_iter <= r:
        return _simple_mpz_batch(blocks_S, blocks_T, current_iter)

    identity = np.eye(block_dim)

    def _inv_or_pinv(mat):
        # Cholesky-based inverse with pseudoinverse fallback.
        try:
            return cho_solve(cho_factor(mat, lower=False, check_finite=False),
                             identity, check_finite=False)
        except np.linalg.LinAlgError:
            return pinv(mat)

    def _mpz_mvp_single(idx):
        S_i = blocks_S[idx]

        # Q = S⁻¹ - G⁻¹
        Q = _inv_or_pinv(S_i) - _inv_or_pinv(blocks_G[idx])

        # Eigenvectors of S Q, ordered by descending eigenvalue.
        eigvals, eigvecs = np.linalg.eig(S_i @ Q)
        eigvecs = eigvecs[:, np.argsort(eigvals)[::-1]]

        # Oblique projector onto the leading r-dimensional eigenspace:
        # identity on the first r coordinates, zero elsewhere.
        selector = np.zeros((current_iter, current_iter))
        selector[:r, :r] = np.eye(r)
        proj = eigvecs @ selector @ pinv(eigvecs)

        # trace(S T⁻¹ P): solve T X = P, then take trace(S X).
        try:
            solved = cho_solve(cho_factor(blocks_T[idx], lower=False, check_finite=False),
                               proj, check_finite=False)
        except np.linalg.LinAlgError:
            solved = pinv(blocks_T[idx]) @ proj

        return np.trace(S_i @ solved) - r

    # Parallel execution across all blocks (joblib preserves input order)
    values = Parallel(n_jobs=-1)(
        delayed(_mpz_mvp_single)(idx) for idx in range(n_blocks)
    )
    return np.asarray(values, dtype=float)

# ============================================================
# ---------------- Dispatcher: choose function ---------------
# ============================================================

def _choose_activity_index_batch(localizer_to_use: str):
    """
    Return the appropriate activity index function depending on the localizer name.
    """
    if localizer_to_use == "mai":
        return lambda Gs, Ss, Ts, current_iter, r: _simple_mai_batch(Gs, Ss, current_iter)
    elif localizer_to_use == "mpz":
        return lambda Gs, Ss, Ts, current_iter, r: _simple_mpz_batch(Ss, Ts, current_iter)
    elif localizer_to_use == "mai_mvp":
        return lambda Gs, Ss, Ts, current_iter, r: _mai_mvp_batch(Gs, Ss, r, current_iter)
    elif localizer_to_use == "mpz_mvp":
        return lambda Gs, Ss, Ts, current_iter, r: _mpz_mvp_batch(Ss, Ts, Gs, r, current_iter)
    else:
        raise ValueError("Allowed: 'mai', 'mpz', 'mai_mvp', 'mpz_mvp'")


# ============================================================
# ---------------- Main function -----------------------------
# ============================================================

def get_activity_index(localizer_to_use: str,
                       H: np.ndarray,
                       R: np.ndarray,
                       N: np.ndarray,
                       n_sources_to_localize: int,
                       r: int,
                       batch_size: int = 16384,
                       show_progress: bool = True
                       ) -> Tuple[List[int], np.ndarray, int, np.ndarray]:
    """
    Greedy source localization algorithm.

    Parameters
    ----------
    localizer_to_use : str
        Which localizer to use: 'mai', 'mpz', 'mai_mvp', 'mpz_mvp'.
    H : np.ndarray
        Leadfield matrix (channels x sources).
    R : np.ndarray
        Measurement covariance matrix.
    N : np.ndarray
        Noise covariance matrix.
    n_sources_to_localize : int
        Number of sources to select.
    r : int
        Rank parameter.
    batch_size : int
        Number of candidates processed per batch.
    show_progress : bool
        Whether to show a progress bar.

    Returns
    -------
    index_max : List[int]
        Indices of selected sources.
    act_values : np.ndarray
        Activity index values for each selected source.
    r : int
        Rank parameter (unchanged).
    H_res : np.ndarray
        Selected columns of H.
    """
    # Precompute transformed matrices with Cholesky, falling back to the
    # pseudoinverse when a covariance matrix is not positive definite.
    # NOTE: choR is kept (or set to None) so the A_TR computation below can
    # reuse the factorization explicitly instead of relying on a swallowed
    # NameError, as the previous bare-except version effectively did.
    try:
        choR = sla.cho_factor(R, lower=False, check_finite=False)
        A_R = sla.cho_solve(choR, H, check_finite=False)
    except np.linalg.LinAlgError:
        choR = None
        A_R = np.linalg.pinv(R) @ H

    try:
        choN = sla.cho_factor(N, lower=False, check_finite=False)
        A_N = sla.cho_solve(choN, H, check_finite=False)
    except np.linalg.LinAlgError:
        A_N = np.linalg.pinv(N) @ H

    if choR is not None:
        A_TR = sla.cho_solve(choR, N @ A_R, check_finite=False)
    else:
        A_TR = (np.linalg.pinv(R) @ N @ np.linalg.pinv(R)) @ H

    # Initialization
    n_channels, n_sources = H.shape
    H_sel = np.zeros((n_channels, n_sources_to_localize), dtype=H.dtype)
    S_sel = G_sel = T_sel = None
    index_max: List[int] = []
    act_values: List[float] = []
    func = _choose_activity_index_batch(localizer_to_use)
    selected_mask = np.zeros(n_sources, dtype=bool)
    all_indices = np.arange(n_sources, dtype=int)

    # Which auxiliary block matrices the chosen localizer needs.
    need_G = ("mai" in localizer_to_use) or (localizer_to_use == "mpz_mvp")
    need_T = ("mpz" in localizer_to_use)

    # Greedy selection loop: one new source per outer iteration.
    for outer_iter in range(n_sources_to_localize):
        best_val = -np.inf
        best_idx = None

        if show_progress:
            pbar = tqdm(total=n_sources,
                        desc=f"iter {outer_iter+1}/{n_sources_to_localize}")

        # Process candidates in batches
        for start in range(0, n_sources, batch_size):
            end = min(n_sources, start + batch_size)
            batch_idx = all_indices[start:end]
            valid = ~selected_mask[batch_idx]
            if not np.any(valid):
                if show_progress:
                    pbar.update(end - start)
                continue
            batch_idx = batch_idx[valid]

            # Build block structures for this batch of candidates.
            A_R_batch = A_R[:, batch_idx]
            H_batch = H[:, batch_idx]
            blocks_S = _build_blocks_batch(H_sel[:, :len(index_max)], S_sel,
                                           A_R_batch, H_batch)
            if need_G:
                A_N_batch = A_N[:, batch_idx]
                blocks_G = _build_blocks_batch(H_sel[:, :len(index_max)], G_sel,
                                               A_N_batch, H_batch)
            else:
                blocks_G = np.zeros_like(blocks_S)
            if need_T:
                A_TR_batch = A_TR[:, batch_idx]
                blocks_T = _build_blocks_batch(H_sel[:, :len(index_max)], T_sel,
                                               A_TR_batch, H_batch)
            else:
                blocks_T = np.zeros_like(blocks_S)

            # Compute activity index values for the batch
            vals = func(blocks_G, blocks_S, blocks_T,
                        current_iter=outer_iter + 1, r=r)
            local_best_idx = int(np.argmax(vals))
            local_best_val = float(vals[local_best_idx])

            # Update global best
            if local_best_val > best_val:
                best_val = local_best_val
                best_idx = int(batch_idx[local_best_idx])

            if show_progress:
                pbar.update(end - start)

        if show_progress:
            pbar.close()

        # No admissible candidate left — stop early.
        if best_idx is None:
            break

        # Save chosen index and its value
        index_max.append(best_idx)
        act_values.append(float(best_val))
        sel_pos = len(index_max) - 1
        H_sel[:, sel_pos:sel_pos+1] = H[:, best_idx:best_idx+1]
        selected_mask[best_idx] = True

        # Incrementally update block matrices S_sel / G_sel / T_sel
        aR_chosen = A_R[:, best_idx:best_idx+1]
        S_sel = _build_blocks_batch(H_sel[:, :sel_pos+1], S_sel,
                                    aR_chosen, H[:, best_idx:best_idx+1])[0]
        if need_G:
            aN_chosen = A_N[:, best_idx:best_idx+1]
            G_sel = _build_blocks_batch(H_sel[:, :sel_pos+1], G_sel,
                                        aN_chosen, H[:, best_idx:best_idx+1])[0]
        if need_T:
            aTR_chosen = A_TR[:, best_idx:best_idx+1]
            T_sel = _build_blocks_batch(H_sel[:, :sel_pos+1], T_sel,
                                        aTR_chosen, H[:, best_idx:best_idx+1])[0]

    # Collect selected sources
    H_res = H[:, index_max].copy() if index_max else np.zeros((n_channels, 0), dtype=H.dtype)

    # Final summary
    print("\n[Activity Index Result]")
    print(f" Selected indices (index_max): {index_max}")
    print(f" Rank parameter (r): {r}")

    return index_max, np.array(act_values, dtype=float), r, H_res
def suggest_n_sources_and_rank(R: np.ndarray,
                               N: np.ndarray,
                               show_plot: bool = True,
                               subject: Optional[str] = None,
                               n_sources_threshold: float = 1,
                               rank_threshold: float = 1.5,
                               **kwargs) -> tuple[int, int]:
    """
    Automatically propose number of sources to localize and rank
    based on Proposition 3 in [1]_.

    Parameters
    ----------
    R : array-like
        Data covariance matrix.
    N : array-like
        Noise covariance matrix.
    show_plot : bool
        Whether to display a graph of the eigenvalues of the
        :math:`RN^{-1}` matrix. Default to True.
    subject : str, optional
        Subject name the analysis is performed for.
    n_sources_threshold : float
        Number of eigenvalues of the :math:`RN^{-1}` matrix above this
        threshold corresponds to the suggested number of sources to
        localize. Default to 1.0. For more details see Observation 1 in [1]_.
    rank_threshold : float
        Number of eigenvalues of the :math:`RN^{-1}` matrix above this
        threshold corresponds to the suggested rank optimization parameter.
        Default to 1.5. For more details see Proposition 3 in [1]_.

    Returns
    -------
    n_sources : int
        Suggested number of sources to localize.
    rank : int
        Suggested rank optimization parameter.

    References
    ----------
    .. [1] NOTE(review): citation text was lost in extraction — restore the
       original reference here.
    """
    # Obtain the eigenvalues of R N^-1, either via the plotting helper
    # (which can also return them) or directly.
    if show_plot:
        _, eigvals = plot_RN_eigenvalues(R=R, N=N, subject=subject,
                                         return_eigvals=True, **kwargs)
    else:
        eigvals = algebra.get_pinv_RN_eigenvals(R=R, N=N)

    # Suggesting number of sources: index of the last eigenvalue above the
    # threshold (+1 to convert to a count).
    n_sources_temp = np.where(eigvals > n_sources_threshold)[0]
    if n_sources_temp.size == 0:
        raise ValueError(
            f"All eigenvalues of $\\mathrm{{RN}}^{{-1}}$ are smaller than {n_sources_threshold}."
        )
    n_sources = n_sources_temp[-1] + 1

    # Suggesting rank: same rule with the (higher) rank threshold.
    rank_temp = np.where(eigvals > rank_threshold)[0]
    if rank_temp.size == 0:
        raise ValueError(
            f"All eigenvalues of $\\mathrm{{RN}}^{{-1}}$ are smaller than {rank_threshold}."
        )
    rank = rank_temp[-1] + 1

    print(f"Suggested number of sources to localize: {n_sources}")
    print(f"Suggested rank is: {rank}")

    return int(n_sources), int(rank)