Source code for mvpure_py.simulation.validate

"""Validation of results achieved on synthetic data."""
# Author: Julia Jurkowska
import matplotlib.pyplot as plt
import mne
import numpy as np
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import mean_squared_error

from ..utils import (
    vertices_to_coordinates,
    transform_leadfield_indices_to_vertices,
    subset_forward
)

from ..localizer import localize


[docs] def localization_error( vertices_true: list[list[int]], localized_vertices: list[list[int]], src: mne.SourceSpaces ) -> dict: """ Compute localization error between true and estimated source vertices. The function converts vertex indices to 3D coordinated using the provided source space and computes pairwise Euclidean distances between true and localized sources. The Hungarian algorithm is used to determine the optimal one-to-one assignment that minimizes total distance error. Parameters ---------- vertices_true : list[list[int]] Ground-truth vertices for each hemisphere in the form ``[lh_vertices, rh_vertices]`` localized_vertices : list[list[int]] Localized vertices for each hemisphere in the form ``[lh_vertices, rh_vertices]`` src : mne.SourceSpaces Source space used to map vertices to 3D coordinates Returns ------- results : dict Dictionary containing localization error statistics: - ``mean_error`` : float Mean distance error (mm). - ``median_error`` : float Median distance error (mm). - ``sum_error`` : float Sum of all matched vertex errors (mm). - ``max_error`` : float Maximum error among matched vertices (mm). - ``min_error`` : float Minimum error among matched vertices (mm). - ``errors_per_vertex`` : ndarray Distance error for each matched vertex (mm). """ pos_true = vertices_to_coordinates(vertices_true, src) pos_loc = vertices_to_coordinates(localized_vertices, src) if len(pos_true) == 0 or len(pos_loc) == 0: raise ValueError("No active vertices found in ground truth or estimated set.") # Pairwise distance matrix dist = cdist(pos_true, pos_loc, metric="euclidean") # Hungarian algorithm for optimal matching row_ind, col_ind = linear_sum_assignment(dist) matched_distances = dist[row_ind, col_ind] # Unit conversion to mm matched_distances = matched_distances * 1000.0 results = { "mean_error": float(np.mean(matched_distances)), "median_error": float(np.median(matched_distances)), "sum_error": float(np.sum(matched_distances)), "max_error": float(np.max(matched_distances)), "min_error": float(np.min(matched_distances)), "errors_per_vertex": matched_distances, } return results
[docs] def evaluate_localization_for_each_rank( subject: str, subjects_dir: str, n_sources: int, R: mne.Covariance, N: mne.Covariance, forward: mne.Forward, true_vertices: list[list[int]], localizer_to_use: list | str, plot_correct_sources_by_rank: bool = True, plot_sum_error_by_rank: bool = False, plot_mean_error_by_rank: bool = True, show_progress: bool = True ): """ Evaluate source localization performance for different rank parameters. The function iteratively runs a source localization method for ranks ``1 ... n_sources`` and compared the localized vertices with ground-truth vertices. For each rank, the number of correctly localized sources and localization error metrics are computed. Parameters ---------- subject : str Subject identifier the analysis is performed for subjects_dir : str Directory where ``subject`` folder is being stored n_sources : int Number of sources to localize R : mne.Covariance Data covariance matrix N : mne.Covariance Noise covariance matrix forward : mne.Forward Forward solution used for source localization. true_vertices : list[list[int]] Ground-truth vertices in the format ``[lh_vertices, rh_vertices]``. localizer_to_use : list | str Type of localizer index to use. Options are: 'mai', 'mpz', 'mai_mvp', 'mpz_mvp'. plot_correct_sources_by_rank : bool, optional Whether to plot the number of correctly localized sources per rank. Default is True. plot_sum_error_by_rank : bool, optional Whether to plot the sum of localization errors per rank. Default is False. plot_mean_error_by_rank : bool, optional Whether to plot the mean localization error per rank. Default is True. show_progress : bool, optional Whether to display progress during localization. Default is True. Returns ------- summary_dict : dict Dictionary summarizing results for each rank. Each entry contains: - ``localized`` : localization result object - ``n_correctly_localized`` : int Number of correctly identified vertices. - ``correctly_localized`` : list[list[int]] Intersection between estimated and true vertices. - ``error_info`` : dict Localization error statistics returned by :func:`localization_error`. """ summary_dict = {} for r in range(1, n_sources + 1): rank_label = f"rank_{r}" summary_dict[rank_label] = {} # Localize for rank r locs_temp = localize( subject=subject, subjects_dir=subjects_dir, localizer_to_use=localizer_to_use, n_sources_to_localize=n_sources, R=R, N=N, forward=forward, r=r, show_progress=show_progress ) lh_vert, lh_idx, rh_vert, rh_idx = transform_leadfield_indices_to_vertices( lf_idx=locs_temp["sources"], src=forward['src'], hemi="both", include_mapping=True ) locs_temp.add_vertices_info( lh_vertices=lh_vert, lh_indices=lh_idx, rh_vertices=rh_vert, rh_indices=rh_idx ) summary_dict[rank_label]['localized'] = locs_temp lh_intersection = list(set(lh_vert).intersection(true_vertices[0])) rh_intersection = list(set(rh_vert).intersection(true_vertices[1])) summary_dict[rank_label]["n_correctly_localized"] = len(lh_intersection) + len(rh_intersection) summary_dict[rank_label]['correctly_localized'] = [sorted(lh_intersection), sorted(rh_intersection)] summary_dict[rank_label]["error_info"] = localization_error( vertices_true=true_vertices, localized_vertices=locs_temp['vertno_in_use'], src=forward['src'] ) if plot_correct_sources_by_rank: _plot_correct_sources_by_rank(summary_dict) if plot_sum_error_by_rank: _plot_sum_error_by_rank(summary_dict) if plot_mean_error_by_rank: _plot_mean_error_by_rank(summary_dict) return summary_dict
[docs] def evaluate_reconstruction(summary_dict: dict, signal, sim_stc: mne.SourceEstimate, fwd: mne.Forward, R: mne.Covariance, N: mne.Covariance, reg=0.05, pick_ori=None, rank=None, **kwargs): """ Evaluate reconstruction quality using LCMV beamforming. For each rank entr in the localization summary, two reconstructions are computed: - LCMV using the full forward model - LCMV using a forward model restricted to the localized sources. Reconstruction performance is evaluated using the mean squared error (MSE) between the simulated source activity and the reconstructed activity for the correctly localized vertices. Parameters ---------- summary_dict : dict Localization summary returned by :func:`evaluate_localization_for_each_rank`. signal : mne.Epochs Sensor-space data used for beamforming. sim_stc : mne.SourceEstimate Simulated ground-truth source time course. fwd : mne.Forward Full forward solution. R : mne.Covariance Data covariance N : mne.Covariance Noise covariance reg : float, optional Regularization parameter for LCMV. Default is 0.05. pick_ori : str | None, optional Orientation selection for beamformer weights. rank : int | None, optional Rank used in beamformer computation. **kwargs Additional arguments passed to ``mne.beamformer.make_lcmv``. """ for key_rank, summary in summary_dict.items(): print(f"{key_rank}") # LCMV on full (unchanged) mne.Forward lcmv_org = mne.beamformer.make_lcmv( signal.info, fwd, data_cov=R, reg=reg, noise_cov=N, pick_ori=pick_ori, weight_norm=None, rank=rank, **kwargs ) stc_lcmv_org = mne.beamformer.apply_lcmv(signal, lcmv_org) # LCMV only on correctly localized sources new_fwd = subset_forward( old_fwd=fwd, localized=summary['localized'], hemi="both" ) lcmv_subset = mne.beamformer.make_lcmv( signal.info, new_fwd, data_cov=R, reg=reg, noise_cov=N, pick_ori=pick_ori, weight_norm=None, rank=rank, **kwargs ) stc_lcmv_subset = mne.beamformer.apply_lcmv(signal, lcmv_subset) # Crop Source Estimate from simulation sim_stc_cropped = sim_stc.copy().crop(tmin=stc_lcmv_subset.times[0], tmax=stc_lcmv_subset.times[-1]) # check assert np.array_equal(stc_lcmv_subset.times, stc_lcmv_org.times) assert np.array_equal(stc_lcmv_subset.times, sim_stc_cropped.times) sim_stc_soi, _ = get_stc_data_for_vertices(sim_stc_cropped, vertices=summary['correctly_localized']) stc_lcmv_org_soi, _ = get_stc_data_for_vertices(stc_lcmv_org, vertices=summary['correctly_localized']) stc_lcmv_subset_soi, _ = get_stc_data_for_vertices(stc_lcmv_subset, vertices=summary['correctly_localized']) mse_sim_lcmv_org = mean_squared_error(sim_stc_soi, stc_lcmv_org_soi) print(f"MSE simulation <-> full LCMV: {mse_sim_lcmv_org}") mse_sim_lcmv_subset = mean_squared_error(sim_stc_soi, stc_lcmv_subset_soi) print(f"MSE simulation <-> subset LCMV: {mse_sim_lcmv_subset}")
[docs] def get_stc_data_for_vertices(stc: mne.SourceEstimate, vertices: list[list[int]]): """ Extract source time courses for specific vertices from a SourceEstimate. Parameters ---------- stc : mne.SourceEstimate Source estimate containing source time courses. vertices : [list[list[int]] Vertices of interest in the format ``[lh_vertices, rh_vertices]`` Returns ------- data : ndarray Source time courses corresponding to the selected vertices with shape ``(n_vertices, n_times)``. indices : ndarray Indices of the selected vertices within ``stc.data``. """ lh_vertices, rh_vertices = vertices lh_vertno, rh_vertno = stc.vertices lh_len = len(lh_vertno) lh_map = {v: i for i, v in enumerate(lh_vertno)} rh_map = {v: i + lh_len for i, v in enumerate(rh_vertno)} indices = [] for v in lh_vertices: if v in lh_map: indices.append(lh_map[v]) for v in rh_vertices: if v in rh_map: indices.append(rh_map[v]) indices = np.array(indices) return stc.data[indices], indices
[docs] def compare_with_strongest_sources_lcmv( signal, n_sources: int, true_vertices: list[list[int]], forward: mne.Forward, R: mne.Covariance, N: mne.Covariance, reg: float = 0.05, pick_ori=None, ): """ Identify the strongest sources based on LCMV reconstruction and compare them to ground-truth sources. Parameters ---------- signal : Sensor space data n_sources : int Number of sources to localize. true_vertices : list[list[int]] Ground-truth vertices in the format ``[lh_vertices, rh_vertices]``. forward : mne.Forward Forward solution used for creating ground-truth data. R : mne.Covariance Data covariance matrix N : mne.Covariance Noise covariance matrix reg : float The regularization for the whitened data covariance. Default to 0.05. pick_ori : Orientation specification. Default to None. """ lcmv = mne.beamformer.make_lcmv( signal.info, forward, R, reg=reg, noise_cov=N, pick_ori=pick_ori, weight_norm="nai", rank=None ) stc = mne.beamformer.apply_lcmv(signal, lcmv) # Maximum NAI value for each source across time peak_nai_per_source = np.max(np.abs(stc.data), axis=1) # Indices of strongest sources top_idx = np.argsort(peak_nai_per_source)[-n_sources:][::-1] top_vertices = [[], []] # [[lh], [rh]] n_lh = len(stc.vertices[0]) for idx in top_idx: if idx < n_lh: top_vertices[0].append(stc.vertices[0][idx]) else: top_vertices[1].append(stc.vertices[1][idx - n_lh]) lh_intersection = list(set(top_vertices[0]).intersection(true_vertices[0])) rh_intersection = list(set(top_vertices[1]).intersection(true_vertices[1])) n_correctly_localized = len(lh_intersection) + len(rh_intersection) error_info = localization_error( vertices_true=true_vertices, localized_vertices=top_vertices, src=forward['src'] ) return top_vertices, n_correctly_localized, error_info
def _plot_correct_sources_by_rank( evaluation_summary_dict: dict, color: str = "indigo" ): """ Plot the number of correctly localized sources as a function of rank. """ correctly_localized = [evaluation_summary_dict[key]["n_correctly_localized"] for key in list(evaluation_summary_dict.keys())] plt.scatter( x=np.arange(1, len(correctly_localized)+1), y=correctly_localized, color=color ) plt.xlim((0, len(correctly_localized)+1)) plt.xticks(np.arange(1, len(correctly_localized)+1, 1), np.arange(1, len(correctly_localized)+1, 1)) plt.ylim((0, len(correctly_localized) + 1)) plt.yticks(np.arange(1, len(correctly_localized) + 1, 1), np.arange(1, len(correctly_localized) + 1, 1)) plt.grid(alpha=0.3) plt.xlabel("Rank parameter") plt.ylabel("Number of correctly localized sources") plt.show() def _plot_sum_error_by_rank( evaluation_summary_dict: dict, color: str = "indigo" ): """ Plot the sum of localization errors as a function of rank. """ error_sum = [evaluation_summary_dict[key]["error_info"]["sum_error"] for key in list(evaluation_summary_dict.keys())] plt.scatter( x=np.arange(1, len(error_sum) + 1), y=error_sum, color=color ) plt.xlim((0, len(error_sum) + 1)) plt.xticks(np.arange(1, len(error_sum) + 1, 1), np.arange(1, len(error_sum) + 1, 1)) plt.grid(alpha=0.3) plt.xlabel("Rank parameter") plt.ylabel("Sum of distance error [mm]") plt.show() def _plot_mean_error_by_rank( evaluation_summary_dict: dict, color: str = "indigo" ): """ Plot the mean localization error as a function of rank. """ mean_sum = [evaluation_summary_dict[key]["error_info"]["mean_error"] for key in list(evaluation_summary_dict.keys())] plt.scatter( x=np.arange(1, len(mean_sum) + 1), y=mean_sum, color=color ) plt.xlim((0, len(mean_sum) + 1)) plt.xticks(np.arange(1, len(mean_sum) + 1, 1), np.arange(1, len(mean_sum) + 1, 1)) plt.grid(alpha=0.3) plt.xlabel("Rank parameter") plt.ylabel("Mean of distance error [mm]") plt.show()