"""
Sequential GMM thresholding implementation.
This module provides the SequentialGMM class for
performing iterative refinement of categorical labels through multiple sequential
thresholding operations.
Classes:
SequentialGMM: Sequential refinement class
"""
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union
from pathlib import Path
import warnings
import anndata as ad
import matplotlib as mpl
from matplotlib import colors as mpl_colors
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
import numpy as np
import pandas as pd
from sklearn.mixture import GaussianMixture
from .base import (
_GaussianMixtureModelInfo,
_DecisionBoundariesModel,
_SingleThresholdingEventModel,
GaussianMixtureModelBase,
_validate_save_path,
)
from .single import GMMThresholding
[docs]
class SequentialGMM(GaussianMixtureModelBase):
"""
Sequential GMM thresholding for iterative population refinement.
This class enables performing multiple sequential GMM thresholding operations
on subsets of cells, where each operation refines a specific categorical label
from a previous thresholding event. This is useful for hierarchical cell type
classification or iterative gating strategies.
Unlike GMMThresholding which thresholds a single feature once,
this class allows:
- Initial thresholding on entire dataset
- Refinement of specific label values through additional thresholding
- Tracking operation provenance (parent-child relationships)
- Multiple operations stored in a single .uns key
Attributes
----------
adata : ad.AnnData
A copy of the input AnnData object, modified during processing.
thresholding_events_key : str
Key in adata.uns for storing all operations.
gmm_kwargs : Dict
Default GMM kwargs (can be overridden per operation).
random_state : int
Random state for reproducibility.
Examples
--------
Example workflow::
# Initialize
seq_gmm = SequentialGMM(
adata=adata,
thresholding_events_key='sequential_thresholding'
)
# Create initial labels on entire dataset
seq_gmm.threshold_entire_dataset(
feature='DNA_content',
label_obs_save_str='cell_cycle',
n_components=2,
ordered_labels=['Low', 'High'],
operation_name='DNA_threshold'
)
# Refine 'Low' cells only
seq_gmm.refine_labels_with_gmm(
feature='Plk1',
obs_label='cell_cycle',
value_to_refine='Low',
n_components=2,
ordered_labels=['Low_neg', 'Low_pos'],
operation_name='Plk1_refinement'
)
# Get modified adata
adata = seq_gmm.return_adata()
"""
def __init__(
self,
adata: ad.AnnData,
thresholding_events_key: str = 'sequential_gmm_thresholding_events',
gmm_kwargs: Optional[dict] = None,
random_state: int = 42,
):
"""
Initialize sequential thresholding object.
Parameters
----------
adata : ad.AnnData
Annotated data matrix (observations x features).
thresholding_events_key : str, optional
The key in `adata.uns` where all thresholding event information will be stored.
Defaults to 'sequential_gmm_thresholding_events'.
gmm_kwargs : Optional[Dict], optional
Default keyword arguments to pass to `sklearn.mixture.GaussianMixture`.
Can be overridden per operation. Defaults to None (becomes {}).
random_state : int, optional
Random state for reproducibility. Defaults to 42.
Raises
------
TypeError
If `adata` is not an AnnData object.
TypeError
If `adata.X` is not a numeric type.
TypeError
If `thresholding_events_key` is not a string.
ValueError
If `thresholding_events_key` is an empty string.
TypeError
If `adata.uns[thresholding_events_key]` exists but is not an OrderedDict.
TypeError
If `gmm_kwargs` is not a dictionary.
"""
# Validate adata
if not isinstance(adata, ad.AnnData):
raise TypeError("adata must be an AnnData.AnnData object.")
elif np.issubdtype(adata.X.dtype, np.object_):
raise TypeError(
"adata.X must be a numeric type. Please convert the data to a numeric type."
)
# Validate thresholding_events_key
if not isinstance(thresholding_events_key, str):
raise TypeError("thresholding_events_key must be a string.")
elif not thresholding_events_key:
raise ValueError("thresholding_events_key cannot be an empty string.")
# Create or validate .uns key
if thresholding_events_key not in adata.uns:
adata.uns[thresholding_events_key] = OrderedDict()
elif not isinstance(adata.uns[thresholding_events_key], OrderedDict):
raise TypeError(
f"The '{thresholding_events_key}' key in the AnnData object's `.uns` attribute must be an OrderedDict."
)
# Validate gmm_kwargs
if gmm_kwargs is None:
gmm_kwargs = {}
elif not isinstance(gmm_kwargs, dict):
raise TypeError("gmm_kwargs must be a dictionary.")
# Store attributes
self.adata = adata.copy()
self.thresholding_events_key = thresholding_events_key
self.gmm_kwargs = gmm_kwargs
self.random_state = random_state
[docs]
def threshold_entire_dataset(
self,
feature: str,
label_obs_save_str: str,
n_components: int,
ordered_labels: List[str],
manual_thresholds: Optional[List[Union[float, int]]] = None,
duplicate_labels: bool = False,
operation_name: Optional[str] = None,
layer: Optional[str] = None,
gmm_kwargs: Optional[dict] = None,
overwrite: bool = False,
) -> None:
"""
Threshold entire dataset to create initial categorical labels.
This method creates a new obs column with categorical labels based on
GMM thresholding of a single feature across all cells. It's a wrapper
around GMMThresholding that stores results in the
sequential thresholding framework.
Parameters
----------
feature : str
Feature name to threshold on (must exist in adata.var_names).
label_obs_save_str : str
New column name in adata.obs for labels.
n_components : int
Number of GMM components to fit.
ordered_labels : List[str]
Labels to assign (length = n_components).
manual_thresholds : Optional[List[Union[float, int]]], optional
Manual threshold values. If None, calculated automatically from GMM.
Length must be n_components - 1. Defaults to None.
duplicate_labels : bool, optional
Allow duplicate labels for label collapsing. Defaults to False.
operation_name : str
Required name for tracking this operation.
layer : Optional[str], optional
Layer to use for data access. If None, uses adata.X. Defaults to None.
gmm_kwargs : Optional[dict], optional
GMM kwargs for this operation. Overrides default if provided. Defaults to None.
overwrite : bool, optional
If True, allows overwriting an existing operation with the same name.
Useful for updating n_components or thresholds. Defaults to False.
Raises
------
ValueError
If operation_name is None or empty.
KeyError
If operation_name already exists in .uns and overwrite=False.
Notes
-----
Other exceptions raised by GMMThresholding.
Examples
--------
::
seq_gmm.threshold_entire_dataset(
feature='DNA_content',
label_obs_save_str='cell_cycle',
n_components=3,
ordered_labels=['G0', 'G1', 'S'],
operation_name='DNA_initial_threshold'
)
"""
# Validate operation_name
if operation_name is None or not operation_name:
raise ValueError("operation_name is required and cannot be empty.")
if operation_name in self.adata.uns[self.thresholding_events_key]:
if not overwrite:
raise KeyError(
f"Operation name '{operation_name}' already exists in "
f"adata.uns['{self.thresholding_events_key}']. Use overwrite=True to update it."
)
else:
warnings.warn(
f"Overwriting existing operation '{operation_name}'.",
UserWarning,
stacklevel=2
)
# Use default gmm_kwargs if not provided
if gmm_kwargs is None:
gmm_kwargs = self.gmm_kwargs
# Create temporary single thresholding instance
# Use a temporary .uns key to avoid conflicts
temp_key = f'_temp_{operation_name}'
gmm_single = GMMThresholding(
adata=self.adata,
feature=feature,
label_obs_save_str=label_obs_save_str,
thresholding_events_key=temp_key,
layer=layer,
gmm_kwargs=gmm_kwargs,
random_state=self.random_state,
)
# Fit and categorize
gmm_single.fit(n_components=n_components)
gmm_single.categorize_samples(
ordered_labels=ordered_labels,
manual_thresholds=manual_thresholds,
duplicate_labels=duplicate_labels,
)
# Extract results from temporary instance
self.adata = gmm_single.return_adata()
# Move operation data from temp key to our key with metadata
# Single class stores with feature name as key
operation_data = self.adata.uns[temp_key][feature]
# Add sequential-specific metadata
operation_data['operation_type'] = 'standard'
operation_data['parent_operation'] = None
operation_data['refined_from_labels'] = None
operation_data['layer'] = layer
# Capture cell counts immediately after this operation
cell_counts_after_operation = {str(k): int(v) for k, v in self.adata.obs[label_obs_save_str].value_counts().items()}
operation_data['cell_counts_after_operation'] = cell_counts_after_operation
# Store in our key with the operation_name
self.adata.uns[self.thresholding_events_key][operation_name] = operation_data
# Clean up temp key
del self.adata.uns[temp_key]
[docs]
def refine_labels_with_gmm(
self,
feature: str,
obs_label: str,
value_to_refine: str,
n_components: int,
ordered_labels: List[str],
duplicate_labels: bool = False,
operation_name: Optional[str] = None,
layer: Optional[str] = None,
gmm_kwargs: Optional[dict] = None,
overwrite: bool = False,
) -> None:
"""
Refine existing categorical labels by thresholding a subset with GMM.
Modifies adata.obs[obs_label] in-place (within the copy), replacing cells
with value_to_refine with new labels based on GMM thresholding.
Parameters
----------
feature : str
Feature to threshold on (e.g., 'Plk1').
obs_label : str
Obs column to modify in-place (e.g., 'cell_cycle').
value_to_refine : str
Which label value to refine (e.g., 'G0').
n_components : int
Number of GMM components to fit.
ordered_labels : List[str]
New labels to assign (e.g., ['G0_low', 'G0_high']).
duplicate_labels : bool, optional
Allow duplicate labels for label collapsing. Defaults to False.
operation_name : str
Required name for tracking this operation.
layer : Optional[str], optional
Layer to use for data access. If None, uses adata.X. Defaults to None.
gmm_kwargs : Optional[dict], optional
GMM kwargs for this operation. Overrides default if provided. Defaults to None.
overwrite : bool, optional
If True, allows overwriting an existing operation with the same name.
Useful for updating n_components or thresholds. Defaults to False.
Raises
------
ValueError
If operation_name is None or empty.
KeyError
If operation_name already exists in .uns and overwrite=False.
KeyError
If obs_label doesn't exist in adata.obs.
ValueError
If value_to_refine is not present in adata.obs[obs_label].
ValueError
If no cells have the value_to_refine.
Examples
--------
::
# Before: adata.obs['cell_cycle'] = ['G0', 'G0', 'G1', 'S', 'G0']
seq_gmm.refine_labels_with_gmm(
feature='Plk1',
obs_label='cell_cycle',
value_to_refine='G0',
n_components=2,
ordered_labels=['G0_low', 'G0_high'],
operation_name='Plk1_G0_refinement'
)
# After: adata.obs['cell_cycle'] = ['G0_low', 'G0_high', 'G1', 'S', 'G0_low']
"""
# Validate operation_name
if operation_name is None or not operation_name:
raise ValueError("operation_name is required and cannot be empty.")
if operation_name in self.adata.uns[self.thresholding_events_key]:
if not overwrite:
raise KeyError(
f"Operation name '{operation_name}' already exists in "
f"adata.uns['{self.thresholding_events_key}']. Use overwrite=True to update it."
)
else:
warnings.warn(
f"Overwriting existing operation '{operation_name}'.",
UserWarning,
stacklevel=2
)
# Validate obs_label exists
if obs_label not in self.adata.obs.columns:
raise KeyError(
f"obs_label '{obs_label}' not found in adata.obs. "
f"Available columns: {list(self.adata.obs.columns)}"
)
# Partition data by label
mask, subset_data = self._partition_data_by_label(
obs_label=obs_label,
value_to_refine=value_to_refine,
feature=feature,
layer=layer,
)
# Use default gmm_kwargs if not provided
if gmm_kwargs is None:
gmm_kwargs = self.gmm_kwargs.copy()
# Ensure random_state is in gmm_kwargs (but don't override if explicitly provided)
if 'random_state' not in gmm_kwargs:
gmm_kwargs['random_state'] = self.random_state
# Fit GMM on subset only
gmm = GaussianMixture(
n_components=n_components,
**gmm_kwargs
)
gmm.fit(subset_data.reshape(-1, 1))
# Extract GMM info
gmm_info = _GaussianMixtureModelInfo(
gmm_kwargs=gmm_kwargs,
means=gmm.means_.flatten(),
covs=gmm.covariances_.flatten(),
weights=gmm.weights_,
n_components=n_components,
data_probs=gmm.predict_proba(subset_data.reshape(-1, 1)),
)
# Handle duplicate labels if needed
if duplicate_labels:
ordered_labels_processed, condensed_data_probs = self._handle_duplicate_labels(
ordered_labels, gmm_info.data_probs
)
gmm_info.condensed_data_probs = condensed_data_probs
else:
ordered_labels_processed = ordered_labels
# Calculate decision boundaries from probabilities
# Convert data_probs back to numpy array (Pydantic stores as list)
probs_array = np.array(gmm_info.condensed_data_probs if duplicate_labels else gmm_info.data_probs)
decision_boundaries = self._calculate_decision_boundaries_from_probs(
feature_values=subset_data,
data_probs=probs_array,
ordered_labels=ordered_labels_processed,
)
# Assign new labels to subset cells
new_labels = self._assign_labels_from_thresholds(
data=subset_data,
thresholds=decision_boundaries.thresholds,
ordered_labels=ordered_labels_processed,
)
# Update labels in-place
self._update_labels_in_place(
obs_label=obs_label,
mask=mask,
new_labels=new_labels,
)
# Capture cell counts immediately after this operation
cell_counts_after_operation = {str(k): int(v) for k, v in pd.Series(new_labels).value_counts().items()}
# Store operation metadata
internal_data = _SingleThresholdingEventModel(
gmm_info=gmm_info,
ordered_gmm_labels=ordered_labels,
decision_boundaries=decision_boundaries,
condensed_labels=ordered_labels_processed if duplicate_labels else None,
feature_name=feature,
gmm_obs_label=obs_label,
)
self._store_refinement_operation(
operation_name=operation_name,
internal_data=internal_data,
obs_label=obs_label,
value_to_refine=value_to_refine,
operation_type='refinement',
layer=layer,
cell_counts_after_operation=cell_counts_after_operation,
)
[docs]
def refine_labels_with_manual_thresholds(
self,
feature: str,
obs_label: str,
value_to_refine: str,
manual_thresholds: List[Union[float, int]],
ordered_labels: List[str],
operation_name: Optional[str] = None,
layer: Optional[str] = None,
overwrite: bool = False,
) -> None:
"""
Refine existing categorical labels using manual thresholds.
Similar to refine_labels_with_gmm() but uses explicit threshold values
instead of fitting a GMM.
Parameters
----------
feature : str
Feature to threshold on.
obs_label : str
Obs column to modify in-place.
value_to_refine : str
Which label value to refine.
manual_thresholds : List[Union[float, int]]
Threshold values. Length must be len(ordered_labels) - 1.
ordered_labels : List[str]
New labels to assign.
operation_name : str
Required name for tracking this operation.
layer : Optional[str], optional
Layer to use for data access. If None, uses adata.X. Defaults to None.
overwrite : bool, optional
If True, allows overwriting an existing operation with the same name.
Useful for updating thresholds. Defaults to False.
Raises
------
ValueError
If operation_name is None or empty.
KeyError
If operation_name already exists in .uns and overwrite=False.
KeyError
If obs_label doesn't exist in adata.obs.
ValueError
If value_to_refine is not present in adata.obs[obs_label].
ValueError
If no cells have the value_to_refine.
ValueError
If len(manual_thresholds) != len(ordered_labels) - 1.
Examples
--------
::
seq_gmm.refine_labels_with_manual_thresholds(
feature='Plk1',
obs_label='cell_cycle',
value_to_refine='G0',
manual_thresholds=[1.5],
ordered_labels=['G0_low', 'G0_high'],
operation_name='Plk1_G0_manual'
)
"""
# Validate operation_name
if operation_name is None or not operation_name:
raise ValueError("operation_name is required and cannot be empty.")
if operation_name in self.adata.uns[self.thresholding_events_key]:
if not overwrite:
raise KeyError(
f"Operation name '{operation_name}' already exists in "
f"adata.uns['{self.thresholding_events_key}']. Use overwrite=True to update it."
)
else:
warnings.warn(
f"Overwriting existing operation '{operation_name}'.",
UserWarning,
stacklevel=2
)
# Validate obs_label exists
if obs_label not in self.adata.obs.columns:
raise KeyError(
f"obs_label '{obs_label}' not found in adata.obs. "
f"Available columns: {list(self.adata.obs.columns)}"
)
# Validate manual_thresholds length
if len(manual_thresholds) != len(ordered_labels) - 1:
raise ValueError(
f"Number of thresholds ({len(manual_thresholds)}) must be "
f"len(ordered_labels) - 1 = {len(ordered_labels) - 1}"
)
# Partition data by label
mask, subset_data = self._partition_data_by_label(
obs_label=obs_label,
value_to_refine=value_to_refine,
feature=feature,
layer=layer,
)
# Assign new labels based on manual thresholds
new_labels = self._assign_labels_from_thresholds(
data=subset_data,
thresholds=manual_thresholds,
ordered_labels=ordered_labels,
)
# Update labels in-place
self._update_labels_in_place(
obs_label=obs_label,
mask=mask,
new_labels=new_labels,
)
# Capture cell counts immediately after this operation
cell_counts_after_operation = {str(k): int(v) for k, v in pd.Series(new_labels).value_counts().items()}
# Store operation metadata (no GMM info for manual thresholds)
decision_boundaries = _DecisionBoundariesModel(thresholds=manual_thresholds)
internal_data = _SingleThresholdingEventModel(
gmm_info=None, # No GMM for manual thresholds
ordered_gmm_labels=ordered_labels,
decision_boundaries=decision_boundaries,
condensed_labels=None,
feature_name=feature,
gmm_obs_label=obs_label,
)
self._store_refinement_operation(
operation_name=operation_name,
internal_data=internal_data,
obs_label=obs_label,
value_to_refine=value_to_refine,
operation_type='refinement_manual',
layer=layer,
cell_counts_after_operation=cell_counts_after_operation,
)
def _partition_data_by_label(
self,
obs_label: str,
value_to_refine: str,
feature: str,
layer: Optional[str]
) -> Tuple[np.ndarray, np.ndarray]:
"""
Extract subset of data for cells with specific label value.
Parameters
----------
obs_label : str
Obs column name to check.
value_to_refine : str
Label value to select.
feature : str
Feature name to extract data for.
layer : Optional[str]
Layer to use (None = .X).
Returns
-------
mask : np.ndarray
Boolean array indicating which cells to refine.
feature_data : np.ndarray
Feature values for those cells (1D array).
Raises
------
ValueError
If value_to_refine is not present in obs_label.
ValueError
If no cells have the value_to_refine.
"""
# Check if value_to_refine exists
unique_values = self.adata.obs[obs_label].unique()
if value_to_refine not in unique_values:
# Check if this operation exists in metadata and may have been previously refined
error_msg = (
f"value_to_refine '{value_to_refine}' not found in adata.obs['{obs_label}'].\n"
f"Available values: {list(unique_values)}"
)
# Check if operation with this value_to_refine exists in history
if hasattr(self.adata, 'uns') and self.thresholding_events_key in self.adata.uns:
metadata = self.adata.uns[self.thresholding_events_key]
for op_name, op_data in metadata.items():
if op_data.get('label_to_refine') == value_to_refine:
refined_labels = op_data.get('ordered_gmm_labels', [])
# Check if those refined labels exist now
existing_refined = [lbl for lbl in refined_labels if lbl in unique_values]
if existing_refined:
error_msg += (
f"\n\nNote: Operation '{op_name}' previously refined '{value_to_refine}' "
f"into {refined_labels}.\n"
f"These labels currently exist: {existing_refined}\n\n"
f"To re-run this analysis:\n"
f" 1. Restart from the initialization of the SequentialGMM object, or\n"
f" 2. Use the exploratory plotting functions to visualize current labels:\n"
f" seq_gmm.plot_feature_distribution_exploratory(...)\n"
f" seq_gmm.plot_feature_strip_plot_exploratory(...)"
)
break
raise ValueError(error_msg)
# Create mask for cells with value_to_refine
mask = self.adata.obs[obs_label] == value_to_refine
# Check if any cells match
if not mask.any():
raise ValueError(
f"No cells found with value '{value_to_refine}' in adata.obs['{obs_label}']"
)
# Extract feature data for subset
if layer is None:
feature_data = self.adata[mask, feature].X
else:
feature_data = self.adata[mask, feature].layers[layer]
# Ensure 1D array
if hasattr(feature_data, 'toarray'):
feature_data = feature_data.toarray()
feature_data = np.asarray(feature_data).flatten()
return mask, feature_data
def _assign_labels_from_thresholds(
self,
data: np.ndarray,
thresholds: List[Union[float, int]],
ordered_labels: List[str],
) -> np.ndarray:
"""
Assign categorical labels based on threshold values.
Parameters
----------
data : np.ndarray
1D array of feature values.
thresholds : List[Union[float, int]]
List of threshold values (sorted low to high).
ordered_labels : List[str]
Labels corresponding to threshold bins.
Returns
-------
np.ndarray
Array of labels (same length as data).
Raises
------
ValueError
If number of labels doesn't match number of thresholds.
"""
# Validate that we have the right number of labels
expected_labels = len(thresholds) + 1
if len(ordered_labels) != expected_labels:
raise ValueError(
f"Number of labels ({len(ordered_labels)}) must equal number of thresholds + 1 ({expected_labels}). "
f"Thresholds: {thresholds}, Labels: {ordered_labels}"
)
# Initialize with first label
labels = np.full(len(data), ordered_labels[0], dtype=object)
# Assign labels based on thresholds
for i, threshold in enumerate(thresholds):
labels[data > threshold] = ordered_labels[i + 1]
return labels
def _update_labels_in_place(
self,
obs_label: str,
mask: np.ndarray,
new_labels: np.ndarray
) -> None:
"""
Update obs column in-place for masked cells.
Parameters
----------
obs_label : str
Obs column to modify.
mask : np.ndarray
Boolean mask indicating which cells to update.
new_labels : np.ndarray
New label values for masked cells.
"""
# Convert to categorical if not already
if not isinstance(self.adata.obs[obs_label].dtype, pd.CategoricalDtype):
self.adata.obs[obs_label] = self.adata.obs[obs_label].astype('category')
# Add new categories if they don't exist
existing_categories = self.adata.obs[obs_label].cat.categories
new_categories = set(new_labels) - set(existing_categories)
if new_categories:
self.adata.obs[obs_label] = self.adata.obs[obs_label].cat.add_categories(
list(new_categories)
)
# Update values for masked cells
self.adata.obs.loc[mask, obs_label] = new_labels
def _store_refinement_operation(
self,
operation_name: str,
internal_data: _SingleThresholdingEventModel,
obs_label: str,
value_to_refine: str,
operation_type: str,
layer: Optional[str],
cell_counts_after_operation: Dict[str, int],
) -> None:
"""
Store operation metadata in .uns.
Parameters
----------
operation_name : str
Name for this operation.
internal_data : _SingleThresholdingEventModel
Pydantic model with thresholding data.
obs_label : str
Which obs column was modified.
value_to_refine : str
Which label value was refined.
operation_type : str
'refinement' or 'refinement_manual'.
layer : Optional[str]
Which layer was used (None = .X).
cell_counts_after_operation : Dict[str, int]
Dictionary of label counts immediately after operation.
"""
# Convert to dict for storage
operation_dict = internal_data.model_dump()
# Add sequential-specific metadata
operation_dict['operation_type'] = operation_type
operation_dict['parent_operation'] = obs_label
operation_dict['refined_from_labels'] = [value_to_refine]
operation_dict['layer'] = layer
operation_dict['cell_counts_after_operation'] = cell_counts_after_operation
# Store in .uns
self.adata.uns[self.thresholding_events_key][operation_name] = operation_dict
[docs]
def return_adata(self) -> ad.AnnData:
"""
Return the modified AnnData object.
Returns
-------
ad.AnnData
Modified AnnData object with all operations applied.
Examples
--------
::
seq_gmm = SequentialGMM(adata)
seq_gmm.threshold_entire_dataset(...)
seq_gmm.refine_labels_with_gmm(...)
adata_modified = seq_gmm.return_adata()
"""
return self.adata
def _get_descendant_labels(
self,
operation_name: str,
original_labels: List[str]
) -> List[str]:
"""
Get all labels that descended from an operation's labels.
This handles cases where subsequent operations refined labels from this operation.
For example, if 'separate_M_phase' created ['M', 'G1/S/G2'], and a later
operation refined 'G1/S/G2' into ['G1', 'S', 'G2'], this returns
['M', 'G1', 'S', 'G2'].
Parameters
----------
operation_name : str
Name of the operation.
original_labels : List[str]
The labels created by this operation.
Returns
-------
List[str]
List of all current labels that descended from original_labels.
"""
all_labels = set(original_labels)
# Check all subsequent operations to see if they refined any of our labels
metadata = self.adata.uns[self.thresholding_events_key]
for op_name, op_data in metadata.items():
# Skip the current operation and operations before it
if op_name == operation_name:
continue
# Check if this operation refined one of our labels
if 'refined_from_labels' in op_data and op_data['refined_from_labels']:
refined_from = op_data['refined_from_labels'][0] # Should be single label
if refined_from in all_labels:
# Replace the parent label with child labels
all_labels.discard(refined_from)
all_labels.update(op_data['ordered_gmm_labels'])
return list(all_labels)
[docs]
def plot_feature_distribution_exploratory(
self,
feature: str,
obs_label: Optional[str] = None,
value_to_subset: Optional[str] = None,
layer: Optional[str] = None,
hist_kwargs: Optional[Dict] = None,
ax: Optional[Axes] = None,
x_axis_limits: Optional[tuple] = None,
) -> Axes:
"""
Plot histogram of a feature distribution for exploratory analysis.
This method allows you to visualize feature distributions WITHOUT running
any thresholding, so you can explore your data and decide on manual thresholds
or the number of components to use for GMM.
Parameters
----------
feature : str
Feature name to plot (must exist in adata.var_names).
obs_label : Optional[str], optional
Obs column to use for subsetting. If provided with value_to_subset,
only plots cells with that label value. If None, plots all cells.
Defaults to None.
value_to_subset : Optional[str], optional
Specific label value to plot. Requires obs_label to be specified.
If None, plots all cells (or all cells in obs_label if provided).
Defaults to None.
layer : Optional[str], optional
Layer to use for data. If None, uses adata.X. Defaults to None.
hist_kwargs : Optional[Dict], optional
Keyword arguments for plt.hist().
Defaults to {'bins': 50, 'color': 'black', 'alpha': 0.7}.
ax : Optional[plt.Axes], optional
Matplotlib axes to plot on. If None, uses current axes. Defaults to None.
x_axis_limits : Optional[tuple], optional
(min, max) for x-axis. Use None for data-driven limits. Defaults to None.
Returns
-------
Axes
The matplotlib axes object.
Raises
------
ValueError
If value_to_subset is provided without obs_label.
KeyError
If obs_label doesn't exist in adata.obs.
ValueError
If value_to_subset is not present in adata.obs[obs_label].
Examples
--------
Explore entire dataset::
seq_gmm.plot_feature_distribution_exploratory(
feature='Int_Intg_DNA_nuc',
hist_kwargs={'bins': 30, 'color': 'steelblue'},
x_axis_limits=(5, 15)
)
plt.title('DNA Content Distribution - All Cells')
plt.show()
Explore specific subset::
seq_gmm.plot_feature_distribution_exploratory(
feature='Int_Intg_DNA_nuc',
obs_label='cell_cycle_phase',
value_to_subset='G1/S/G2',
hist_kwargs={'bins': 30, 'color': 'steelblue'},
x_axis_limits=(5, 15)
)
plt.title('DNA Content in G1/S/G2 Cells - Exploratory')
plt.show()
"""
if ax is None:
ax = plt.gca()
# Validate parameters
if value_to_subset is not None and obs_label is None:
raise ValueError(
"obs_label must be provided when value_to_subset is specified."
)
# Get subset of adata if value_to_subset provided
if value_to_subset is not None:
mask, _ = self._partition_data_by_label(
obs_label=obs_label,
value_to_refine=value_to_subset,
feature=feature,
layer=layer,
)
adata_subset = self.adata[mask, :]
else:
adata_subset = self.adata
# Use base class histogram plotting method
ax = super()._plot_hist_base(
adata=adata_subset,
feature=feature,
layer=layer,
hist_kwargs=hist_kwargs,
ax=ax,
x_axis_limits=x_axis_limits,
)
return ax
[docs]
def plot_feature_strip_plot_exploratory(
self,
feature: str,
obs_label: Optional[str] = None,
value_to_subset: Optional[str] = None,
layer: Optional[str] = None,
hist_kwargs: Optional[Dict] = None,
strip_plot_kwargs: Optional[Dict] = None,
scatter_density: bool = True,
x_axis_limits: Optional[tuple] = None,
vmax: Optional[Union[int, float]] = None,
) -> tuple:
"""
Plot strip plot + histogram for exploratory analysis.
Similar to plot_strip_plot_histogram_with_decision_boundaries() but WITHOUT
decision boundaries, for exploring data before running threshold operations.
This method allows you to visualize feature distributions WITHOUT running
any thresholding, so you can explore your data and decide on manual thresholds
or the number of components to use for GMM.
Parameters
----------
feature : str
Feature name to plot (must exist in adata.var_names).
obs_label : Optional[str], optional
Obs column to use for subsetting. If provided with value_to_subset,
only plots cells with that label value. If None, plots all cells.
Defaults to None.
value_to_subset : Optional[str], optional
Specific label value to plot. Requires obs_label to be specified.
If None, plots all cells (or all cells in obs_label if provided).
Defaults to None.
layer : Optional[str], optional
Layer to use for data. If None, uses adata.X. Defaults to None.
hist_kwargs : Optional[Dict], optional
Keyword arguments for histogram. Defaults to None.
strip_plot_kwargs : Optional[Dict], optional
Keyword arguments for strip plot. Only used when scatter_density=False.
Defaults to None.
scatter_density : bool, optional
If True, uses density-based coloring. If False, uses uniform scatter plot.
Defaults to True.
x_axis_limits : Optional[tuple], optional
(min, max) for x-axis. Use None for data-driven limits. Defaults to None.
vmax : Optional[Union[int, float]], optional
Maximum density value for colormap. Only used when scatter_density=True.
If None, auto-calculated. Defaults to None.
Returns
-------
tuple
(fig, (ax_strip, ax_hist)) - Figure and axes objects.
Raises
------
ValueError
If value_to_subset is provided without obs_label.
KeyError
If obs_label doesn't exist in adata.obs.
ValueError
If value_to_subset is not present in adata.obs[obs_label].
Examples
--------
Explore entire dataset::
fig, (ax_strip, ax_hist) = seq_gmm.plot_feature_strip_plot_exploratory(
feature='Int_Intg_DNA_nuc',
scatter_density=True,
x_axis_limits=(5, 15)
)
plt.suptitle('DNA Content Distribution - All Cells')
plt.show()
Explore specific subset::
fig, (ax_strip, ax_hist) = seq_gmm.plot_feature_strip_plot_exploratory(
feature='Int_Intg_DNA_nuc',
obs_label='cell_cycle_phase',
value_to_subset='G1/S/G2',
scatter_density=True,
x_axis_limits=(5, 15)
)
plt.suptitle('DNA Content Distribution - G1/S/G2 Cells')
plt.show()
"""
# Validate parameters
if value_to_subset is not None and obs_label is None:
raise ValueError(
"obs_label must be provided when value_to_subset is specified."
)
# Get subset of adata if value_to_subset provided
if value_to_subset is not None:
mask, _ = self._partition_data_by_label(
obs_label=obs_label,
value_to_refine=value_to_subset,
feature=feature,
layer=layer,
)
adata_subset = self.adata[mask, :]
else:
adata_subset = self.adata
# Call base method to create strip plot and histogram (without decision boundaries)
fig, ax_strip, ax_hist = super()._plot_strip_plot_base(
adata=adata_subset,
feature=feature,
layer=layer,
obs_label=None, # No labels for exploratory
ordered_labels=None, # No labels for exploratory
scatter_density=scatter_density,
y_axis_limits=x_axis_limits, # Note: x_axis becomes y_axis in vertical plot
hist_kwargs=hist_kwargs,
strip_plot_kwargs=strip_plot_kwargs,
cmap=mpl.colormaps['plasma'],
vmax=vmax,
)
return fig, (ax_strip, ax_hist)
[docs]
def plot_hist_distribution_with_boundaries(
self,
operation_name: str,
num_std: int = 5,
title: Optional[str] = None,
hist_kwargs: Optional[Dict] = None,
cmap: Optional[mpl_colors.Colormap] = None,
ax: Optional[Axes] = None,
x_axis_limits: Optional[tuple] = None,
resolution: int = 1000,
save_path: Optional[Union[str, Path]] = None,
) -> Axes:
"""
Plot histogram with boundaries for a specific operation.
Parameters
----------
operation_name : str
Name of operation to plot (from .uns keys).
num_std : int, optional
Number of standard deviations for GMM plotting. Defaults to 5.
title : Optional[str], optional
Plot title. If not provided, defaults to feature name.
Pass empty string '' to suppress title. Defaults to None.
hist_kwargs : Optional[Dict], optional
Kwargs for histogram. Defaults to None.
cmap : plt.cm.ScalarMappable, optional
Colormap. Defaults to 'rainbow'.
ax : plt.Axes, optional
Axes to plot on. If None, creates new. Defaults to None.
x_axis_limits : Optional[tuple], optional
X-axis limits (min, max). Defaults to None.
resolution : int, optional
Resolution for plotting. Defaults to 1000.
save_path : Optional[Union[str, Path]], optional
Path to save the figure. Parent directory must exist. Defaults to None.
Returns
-------
Axes
The matplotlib axes object. Call plt.show() to display it.
Raises
------
KeyError
If operation_name not found in .uns.
ValueError
If operation has no decision boundaries.
ValueError
If resolution <= 0 or <= n_components.
FileNotFoundError
If save_path parent directory doesn't exist.
Examples
--------
::
ax = seq_gmm.plot_hist_distribution_with_boundaries('Plk1_refinement')
plt.show()
"""
# Validate save path before generating the figure
save_path = _validate_save_path(save_path)
# Set default colormap if not provided
if cmap is None:
cmap = plt.get_cmap('rainbow')
# Validate operation_name exists
if operation_name not in self.adata.uns[self.thresholding_events_key]:
raise KeyError(
f"Operation '{operation_name}' not found in "
f"adata.uns['{self.thresholding_events_key}']. "
f"Available operations: {list(self.adata.uns[self.thresholding_events_key].keys())}"
)
# Load operation data from .uns
op_data = self.adata.uns[self.thresholding_events_key][operation_name]
internal_data = _SingleThresholdingEventModel(**op_data)
feature = op_data['feature_name']
layer = op_data.get('layer', None)
obs_label = op_data['gmm_obs_label']
ordered_labels = op_data['ordered_gmm_labels']
# Validate decision boundaries exist
if internal_data.decision_boundaries is None:
raise ValueError(
f"Decision boundaries not found for operation '{operation_name}'. "
"This should not happen."
)
# Validate resolution
if resolution <= 0:
raise ValueError("Resolution must be a positive integer.")
if internal_data.gmm_info is not None:
if resolution <= internal_data.gmm_info.n_components:
raise ValueError(
"Resolution must be greater than the number of GMM components."
)
# Filter adata to only include cells with labels from this operation
# For refinement operations, we need to include all descendant labels
# since subsequent operations may have further refined the labels
if 'refined_from_labels' in op_data and op_data['refined_from_labels']:
# This was a refinement operation - find all labels that descended from it
labels_to_plot = self._get_descendant_labels(operation_name, ordered_labels)
else:
# This was an initial threshold operation - use ordered_labels directly
labels_to_plot = ordered_labels
mask = self.adata.obs[obs_label].isin(labels_to_plot)
adata_subset = self.adata[mask, :]
# Call base class plotting method with explicit parameters
ax = super()._plot_hist_base(
adata=adata_subset,
feature=feature,
layer=layer,
hist_kwargs=hist_kwargs,
ax=ax,
x_axis_limits=x_axis_limits,
)
# Plot GMM components if this was GMM-based (not manual)
if internal_data.gmm_info is not None:
ax = super()._plot_gmm_components(
ax=ax,
adata=self.adata,
feature=feature,
internal_data=internal_data,
num_std=num_std,
resolution=resolution,
cmap=cmap
)
# Plot decision boundaries
ax = super()._plot_vertical_linear_decision_boundaries(
ax=ax,
internal_data=internal_data,
resolution=resolution,
cmap=cmap
)
# Plot legend
super()._plot_sample_catergory_legend(
ax=ax,
internal_data=internal_data,
cmap=cmap,
legend_kwargs=None
)
# Add title (default to feature name, allow override or suppression)
if title is None:
title = feature
if title: # Only add title if not empty string
ax.set_title(title)
# Save figure if save_path is provided
if save_path:
plt.savefig(save_path, bbox_inches='tight', dpi=300)
print(f"Figure saved to: {save_path}")
return ax
[docs]
def plot_strip_plot_histogram_with_decision_boundaries(
self,
operation_name: str,
cmap: Optional[mpl_colors.Colormap] = None,
y_axis_limits: Optional[Tuple[float, float]] = None,
resolution: int = 1000,
scatter_density: bool = True,
vmax: Optional[Union[int, float]] = None,
hist_kwargs: Optional[Dict] = None,
strip_plot_kwargs: Optional[Dict] = None,
title: Optional[str] = None,
) -> Figure:
"""
Plot 1D strip plot with histogram and decision boundaries for a specific operation.
This method wraps the base class implementation to provide visualization
for sequential thresholding operations. It creates a density strip plot
(or label-colored scatter) alongside a horizontal histogram showing the
distribution and decision boundaries for the specified operation.
Parameters
----------
operation_name : str
Name of operation to plot (from .uns keys).
cmap : plt.cm.ScalarMappable, optional
Colormap for density or labels. Defaults to mpl.colormaps['plasma'].
y_axis_limits : Optional[Tuple[float, float]], optional
Y-axis limits (min, max). If None, uses data min/max. Defaults to None.
resolution : int, optional
Resolution for boundary plotting. Defaults to 1000.
scatter_density : bool, optional
If True, color by density; if False, color by labels. Defaults to True.
vmax : Optional[Union[int, float]], optional
Maximum density value for colormap. If None, auto-calculated. Defaults to None.
hist_kwargs : Optional[Dict], optional
Kwargs for histogram (bins, color, etc.). Defaults to None.
strip_plot_kwargs : Optional[Dict], optional
Kwargs for strip plot scatter (e.g., s, alpha, marker).
Only used when scatter_density=False. Defaults to None.
title : Optional[str], optional
Title for the plot. If not provided, defaults to feature name.
Pass empty string '' to suppress title. Defaults to None.
Returns
-------
Figure
The matplotlib figure object. Call plt.show() to display it.
Raises
------
KeyError
If operation_name not found in .uns.
ValueError
If operation has no decision boundaries.
Examples
--------
Basic usage with label-colored scatter::
fig = seq_gmm.plot_strip_plot_histogram_with_decision_boundaries(
operation_name='separate_M_phase',
scatter_density=False
)
plt.show()
Custom title::
fig = seq_gmm.plot_strip_plot_histogram_with_decision_boundaries(
operation_name='separate_M_phase',
scatter_density=False,
title='M Phase Separation'
)
plt.show()
Customize strip plot appearance::
fig = seq_gmm.plot_strip_plot_histogram_with_decision_boundaries(
operation_name='separate_M_phase',
scatter_density=False,
strip_plot_kwargs={'s': 5, 'alpha': 0.8, 'marker': 'o'}
)
plt.show()
"""
# Validate operation_name exists
if operation_name not in self.adata.uns[self.thresholding_events_key]:
raise KeyError(
f"Operation '{operation_name}' not found in "
f"adata.uns['{self.thresholding_events_key}']. "
f"Available operations: {list(self.adata.uns[self.thresholding_events_key].keys())}"
)
# Load operation data from .uns
op_data = self.adata.uns[self.thresholding_events_key][operation_name]
internal_data = _SingleThresholdingEventModel(**op_data)
feature = op_data['feature_name']
layer = op_data.get('layer', None)
obs_label = op_data['gmm_obs_label']
ordered_labels = op_data['ordered_gmm_labels']
# Validate decision boundaries exist
if internal_data.decision_boundaries is None:
raise ValueError(
f"Decision boundaries not found for operation '{operation_name}'."
)
# Filter adata to only include cells with labels from this operation
# For refinement operations, we need to include all descendant labels
# since subsequent operations may have further refined the labels
if 'refined_from_labels' in op_data and op_data['refined_from_labels']:
# This was a refinement operation - find all labels that descended from it
labels_to_plot = self._get_descendant_labels(operation_name, ordered_labels)
else:
# This was an initial threshold operation - use ordered_labels directly
labels_to_plot = ordered_labels
mask = self.adata.obs[obs_label].isin(labels_to_plot)
adata_subset = self.adata[mask, :]
# Set default colormap if not provided
if cmap is None:
cmap = mpl.colormaps['plasma']
# Call base class implementation and return figure
return super()._plot_strip_plot_histogram_with_decision_boundaries(
adata=adata_subset,
feature=feature,
layer=layer,
obs_label=obs_label,
ordered_labels=ordered_labels,
internal_data=internal_data,
cmap=cmap,
y_axis_limits=y_axis_limits,
resolution=resolution,
scatter_density=scatter_density,
vmax=vmax,
hist_kwargs=hist_kwargs,
strip_plot_kwargs=strip_plot_kwargs,
title=title,
)