""" Functions used during source localization. """
# Authors: Julia Jurkowska, Tomasz Piotrowski
import os
from typing import List, Optional, Tuple
import numpy as np
from joblib import Parallel, delayed
from scipy.linalg import cho_factor, cho_solve, pinv
from tqdm import tqdm
# Number of threads used by BLAS/LAPACK - ADJUSTABLE
# NOTE(review): numpy and scipy are imported above this assignment. Common
# BLAS builds (e.g. OpenBLAS) typically read OMP_NUM_THREADS when the library
# is loaded, so setting it here may not change the thread count of the
# already-loaded BLAS — consider setting it before the numpy/scipy imports or
# in the launching environment (confirm for the BLAS build in use).
# This plain assignment also overrides any value the user already exported.
os.environ["OMP_NUM_THREADS"] = "10"
# ============================================================
# -------- PSD low-rank pseudo-inverse (apply-only) ----------
# ============================================================
[docs]
class PSDPinvApply:
    """
    Robust pseudo-inverse operator for PSD (possibly low-rank) matrix A:
        A^+ ≈ U diag(1/lam) U^T   (truncated eigenpairs)
    Use: op.apply(X) = A^+ @ X

    Attributes
    ----------
    U : (m, r) ndarray
        Retained eigenvectors, one per column.
    lam : (r,) ndarray
        Retained eigenvalues (matching the columns of ``U``).
    cutoff : float
        Eigenvalue threshold used when the eigenpairs were selected.
    """

    __slots__ = ("U", "lam", "cutoff")

    def __init__(self, U: np.ndarray, lam: np.ndarray, cutoff: float):
        self.U = U
        self.lam = lam
        self.cutoff = cutoff

    def apply(self, X: np.ndarray) -> np.ndarray:
        """
        Return ``A^+ @ X``.

        ``X`` may be a single vector of shape (m,) or a matrix of shape
        (m, n); the result has the same shape as ``X``.
        """
        if self.U.size == 0:
            # Zero-rank operator: A^+ is the zero matrix.
            return np.zeros_like(X)
        UtX = self.U.T @ X
        # Scale row i of U^T X by 1/lam[i]. For a 1-D X the eigenvalues must
        # not gain a trailing axis: (r,) / (r, 1) would broadcast to (r, r)
        # and produce a wrongly-shaped result.
        scale = self.lam if UtX.ndim == 1 else self.lam[:, None]
        return self.U @ (UtX / scale)


def psd_pinv_operator(A: np.ndarray,
                      rmax: Optional[int] = None,
                      eps_rel: float = 1e-10) -> PSDPinvApply:
    """
    Build a robust pseudo-inverse *operator* for PSD A (low-rank supported).

    Parameters
    ----------
    A : (m,m) PSD matrix (may be singular)
        Symmetrized as 0.5 * (A + A.T) before the eigendecomposition.
    rmax : Optional[int]
        Keep at most rmax eigenpairs (largest). ``0`` yields the zero
        operator; ``None`` means no limit.
    eps_rel : float
        Keep eigenvalues > eps_rel * max_eig. Eigenvalues at or below this
        threshold (including numerically negative ones) are discarded.

    Returns
    -------
    PSDPinvApply
        .apply(X) computes A^+ @ X. Floating/complex inputs keep their
        dtype; integer/bool inputs yield the dtype produced by ``eigh``
        (casting eigenvectors back to int would truncate them).

    Raises
    ------
    ValueError
        If ``rmax`` is negative.
    """
    if rmax is not None:
        rmax = int(rmax)
        if rmax < 0:
            raise ValueError(f"rmax must be non-negative, got {rmax}")

    As = 0.5 * (A + A.T)  # enforce symmetry numerically
    lam_all, U_all = np.linalg.eigh(As)  # eigenvalues in ascending order
    out_dtype = A.dtype if np.issubdtype(A.dtype, np.inexact) else U_all.dtype

    lam_max = lam_all[-1] if lam_all.size else 0.0
    if lam_max <= 0:
        return PSDPinvApply(
            U=np.zeros((A.shape[0], 0), dtype=out_dtype),
            lam=np.zeros((0,), dtype=out_dtype),
            cutoff=0.0
        )

    cutoff = eps_rel * float(lam_max)
    keep = lam_all > cutoff
    if rmax is not None and int(np.count_nonzero(keep)) > rmax:
        # Index from the front rather than with a negative slice so that
        # rmax == 0 selects nothing (a [-0:] slice would select everything).
        largest = np.argsort(lam_all)[lam_all.size - rmax:]
        in_top = np.zeros_like(keep)
        in_top[largest] = True
        keep &= in_top

    lam = lam_all[keep].astype(out_dtype, copy=False)
    U = U_all[:, keep].astype(out_dtype, copy=False)
    return PSDPinvApply(U=U, lam=lam, cutoff=cutoff)
# ============================================================
# ---------------- Helper: build batch blocks ----------------
# ============================================================
def _build_blocks_batch(H_sel: np.ndarray,
M_sel: Optional[np.ndarray],
A_M_batch: np.ndarray,
H_batch: np.ndarray) -> np.ndarray:
"""
Construct block matrices for a batch of candidate sources.
"""
k = 0 if M_sel is None else M_sel.shape[0]
B = A_M_batch.shape[1]
m = k + 1
dtype = A_M_batch.dtype
if k == 0:
s_batch = np.sum(A_M_batch * H_batch, axis=0)
return s_batch.reshape(B, 1, 1).astype(dtype, copy=False)
a_batch = H_sel[:, :k].T @ A_M_batch
s_batch = np.sum(A_M_batch * H_batch, axis=0)
blocks = np.empty((B, m, m), dtype=dtype)
blocks[:, :k, :k] = M_sel[None, :, :]
blocks[:, :k, k:k+1] = a_batch.T.reshape(B, k, 1)
blocks[:, k:k+1, :k] = a_batch.T.reshape(B, 1, k)
blocks[:, k, k] = s_batch
return blocks
# ============================================================
# ---------------- Activity index: simplified ----------------
# ============================================================
def _simple_mai_batch(blocks_G: np.ndarray,
blocks_S: np.ndarray,
current_iter: int) -> np.ndarray:
"""
MAI = trace(G S⁻¹) - current_iter
"""
B, m, _ = blocks_S.shape
if m == 1:
return blocks_G[:, 0, 0] / blocks_S[:, 0, 0] - current_iter
I = np.eye(m)
def _process_block(i):
try:
invS = cho_solve(cho_factor(blocks_S[i], lower=False, check_finite=False),
I, check_finite=False)
except np.linalg.LinAlgError:
invS = pinv(blocks_S[i])
return np.trace(blocks_G[i] @ invS) - current_iter
results = Parallel(n_jobs=-1)(delayed(_process_block)(i) for i in range(B))
return np.array(results, dtype=float)
# ============================================================
# ---------------- Activity index: full versions --------------
# ============================================================
def _mai_mvp_batch(blocks_G: np.ndarray,
blocks_S: np.ndarray,
r: int,
current_iter: int) -> np.ndarray:
"""
MAI-MVP = sum of top-r eigenvalues of (G S⁻¹) - r
"""
B, m, _ = blocks_S.shape
if current_iter <= r:
return _simple_mai_batch(blocks_G, blocks_S, current_iter)
I = np.eye(m)
def _process_block(i):
if m == 1:
return blocks_G[i, 0, 0] / blocks_S[i, 0, 0] - r
try:
invS = cho_solve(cho_factor(blocks_S[i], lower=False, check_finite=False),
I, check_finite=False)
except np.linalg.LinAlgError:
invS = pinv(blocks_S[i])
M = blocks_G[i] @ invS
s, _ = np.linalg.eig(M)
s_sorted = np.sort(s)[::-1].real
return np.sum(s_sorted[:r]) - r
results = Parallel(n_jobs=-1)(delayed(_process_block)(i) for i in range(B))
return np.array(results, dtype=float)
# ============================================================
# ---------------- Dispatcher: choose function ---------------
# ============================================================
def _choose_activity_index_batch(localizer_to_use: str):
"""
Return the appropriate activity index function depending on the localizer name.
"""
if localizer_to_use == "mai":
return lambda Gs, Ss, Ts, current_iter, r: _simple_mai_batch(Gs, Ss, current_iter)
elif localizer_to_use == "mai_mvp":
return lambda Gs, Ss, Ts, current_iter, r: _mai_mvp_batch(Gs, Ss, r, current_iter)
else:
raise ValueError("Allowed: 'mai' and 'mai_mvp'")
# ============================================================
# ---------------- Main function -----------------------------
# ============================================================
[docs]
def get_activity_index(localizer_to_use: str,
                       H: np.ndarray,
                       R: np.ndarray,
                       N: np.ndarray,
                       n_sources_to_localize: int,
                       r: int,
                       batch_size: int = 16384,
                       show_progress: bool = True
                       ) -> Tuple[List[int], np.ndarray, int, np.ndarray]:
    """
    Greedy source localization algorithm.

    At each outer iteration every not-yet-selected column of ``H`` is scored
    jointly with the columns already selected, and the highest-scoring
    candidate is appended to the selection.

    Parameters
    ----------
    localizer_to_use : str
        ``"mai"`` or ``"mai_mvp"``; anything else raises ``ValueError``.
    H : (n_channels, n_sources) ndarray
        Candidate columns, one per source.
    R, N : (n_channels, n_channels) ndarray
        PSD (possibly low-rank) matrices, applied through truncated
        eigen pseudo-inverses.
    n_sources_to_localize : int
        Maximum number of greedy iterations / selected sources.
    r : int
        Rank parameter forwarded to the activity index.
    batch_size : int
        Number of candidates scored per batch.
    show_progress : bool
        Show a tqdm progress bar for each outer iteration.

    Returns
    -------
    index_max : list of int
        Selected column indices, in selection order.
    act_values : ndarray of float
        Activity-index value at which each column was selected.
    r : int
        The rank parameter, echoed back.
    H_res : ndarray
        Selected columns of ``H`` in ascending column-index order.

    Notes
    -----
    Candidates whose score is NaN are never selected. The loop stops early
    if no selectable candidate remains.
    """
    # Resolve the activity index first so an invalid `localizer_to_use`
    # fails before the (expensive) eigendecompositions below.
    func = _choose_activity_index_batch(localizer_to_use)
    need_G = ("mai" in localizer_to_use) or (localizer_to_use == "mpz_mvp")
    need_T = ("mpz" in localizer_to_use)

    # --------------------------------------------------------
    # Robust precompute for PSD low-rank R and N (NO Cholesky)
    # --------------------------------------------------------
    # A_R  = R^+ H
    # A_N  = N^+ H          (only needed for the G blocks)
    # A_TR = R^+ N R^+ H    (only needed for the T blocks; same object as
    #                        pinv(R) @ N @ pinv(R) @ H)
    # --------------------------------------------------------
    R_pinv = psd_pinv_operator(R, rmax=None, eps_rel=1e-10)
    A_R = R_pinv.apply(H)
    A_N = (psd_pinv_operator(N, rmax=None, eps_rel=1e-10).apply(H)
           if need_G else None)
    A_TR = R_pinv.apply(N @ A_R) if need_T else None

    # Initialization
    n_channels, n_sources = H.shape
    H_sel = np.zeros((n_channels, n_sources_to_localize), dtype=H.dtype)
    S_sel = G_sel = T_sel = None
    index_max: List[int] = []
    act_values: List[float] = []
    selected_mask = np.zeros(n_sources, dtype=bool)
    all_indices = np.arange(n_sources, dtype=int)

    # Greedy selection loop
    for outer_iter in range(n_sources_to_localize):
        best_val = -np.inf
        best_idx = None
        H_prev = H_sel[:, :len(index_max)]  # columns selected so far

        # Context manager closes the bar even if scoring raises.
        with tqdm(total=n_sources,
                  desc=f"iter {outer_iter+1}/{n_sources_to_localize}",
                  disable=not show_progress) as pbar:
            for start in range(0, n_sources, batch_size):
                end = min(n_sources, start + batch_size)
                batch_idx = all_indices[start:end]
                batch_idx = batch_idx[~selected_mask[batch_idx]]
                if batch_idx.size == 0:
                    pbar.update(end - start)
                    continue

                # Build block structures
                H_batch = H[:, batch_idx]
                blocks_S = _build_blocks_batch(H_prev, S_sel, A_R[:, batch_idx], H_batch)
                if need_G:
                    blocks_G = _build_blocks_batch(H_prev, G_sel, A_N[:, batch_idx], H_batch)
                else:
                    blocks_G = np.zeros_like(blocks_S)
                if need_T:
                    blocks_T = _build_blocks_batch(H_prev, T_sel, A_TR[:, batch_idx], H_batch)
                else:
                    blocks_T = np.zeros_like(blocks_S)

                # Compute activity index values for the batch
                vals = np.asarray(func(blocks_G, blocks_S, blocks_T,
                                       current_iter=outer_iter + 1, r=r),
                                  dtype=float)
                # np.argmax returns the position of a NaN when one is present,
                # which would hide the true best candidate of this batch.
                vals = np.where(np.isnan(vals), -np.inf, vals)
                local_best_idx = int(np.argmax(vals))
                local_best_val = float(vals[local_best_idx])

                # Update global best
                if local_best_val > best_val:
                    best_val = local_best_val
                    best_idx = int(batch_idx[local_best_idx])
                pbar.update(end - start)

        if best_idx is None:
            break

        # Save chosen index and its value
        index_max.append(best_idx)
        act_values.append(float(best_val))
        sel_pos = len(index_max) - 1
        H_sel[:, sel_pos:sel_pos+1] = H[:, best_idx:best_idx+1]
        selected_mask[best_idx] = True

        # Incrementally grow S_sel / G_sel / T_sel by one row and column
        h_best = H[:, best_idx:best_idx+1]
        H_upto = H_sel[:, :sel_pos+1]
        S_sel = _build_blocks_batch(H_upto, S_sel, A_R[:, best_idx:best_idx+1], h_best)[0]
        if need_G:
            G_sel = _build_blocks_batch(H_upto, G_sel, A_N[:, best_idx:best_idx+1], h_best)[0]
        if need_T:
            T_sel = _build_blocks_batch(H_upto, T_sel, A_TR[:, best_idx:best_idx+1], h_best)[0]

    # Collect selected sources
    H_res = H[:, sorted(index_max)].copy() if index_max else np.zeros((n_channels, 0), dtype=H.dtype)

    # Final summary
    print("\n[Activity Index Result]")
    print(f" Selected indices (index_max): {index_max}")
    print(f" Index max values: {np.array(act_values, dtype=float)}")
    print(f" Rank parameter (r): {r}\n")
    return index_max, np.array(act_values, dtype=float), r, H_res