Celcomen
Project home page here.
Main workflow classes
The celcomen class
- class celcomen.models.celcomen.celcomen(input_dim, output_dim, n_neighbors, seed=0)
Bases: Module
A k-hop Graph Neural Network model for disentangling inter- and intra-cellular gene regulation, and then leveraging the result to predict spatial counterfactuals.
Parameters
- input_dim : int
Dimensionality of the input features (gene expression data).
- output_dim : int
Dimensionality of the output features.
- n_neighbors : int
The number of neighbours used in the spatial graph to model cell-cell interactions.
- seed : int, optional
Seed for random number generation to ensure reproducibility. Default is 0.
Attributes
- conv1 : GCNConv
A graph convolutional layer that models gene-to-gene interactions (G2G).
- lin : torch.nn.Linear
A linear layer that models intracellular gene regulation.
- n_neighbors : int
The number of neighbours for spatial graph construction.
- gex : torch.nn.Parameter or None
Stores the gene expression matrix used for the forward pass. Set to None initially.
Methods
- set_g2g(g2g)
Sets the gene-to-gene (G2G) interaction matrix artificially.
- set_g2g_intra(g2g_intra)
Sets the intracellular regulation matrix artificially.
- set_gex(gex)
Sets the gene expression matrix artificially.
- forward(edge_index, batch)
Forward pass to compute the gene-to-gene and intracellular messages, and the log partition function estimate.
- log_Z_mft(edge_index, batch)
Computes the Mean Field Theory (MFT) approximation to the partition function.
- z_interaction(num_spots, g)
Provides an approximation for the interaction term in the partition function to prevent numerical instability due to exploding exponentials.
Examples
>>> import torch
>>> from celcomen.models.celcomen import celcomen
>>> model = celcomen(input_dim=1000, output_dim=100, n_neighbors=6, seed=42)
>>> edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
>>> batch = torch.tensor([0, 1], dtype=torch.long)
>>> model.set_gex(torch.randn(100, 1000))
>>> msg, msg_intra, log_z_mft = model(edge_index, batch)
>>> print(log_z_mft)
- forward(edge_index, batch)
Forward pass for the model, computing gene-to-gene and intracellular messages, and estimating the log partition function using Mean Field Theory (MFT).
Parameters
- edge_index : torch.Tensor
Tensor representing the graph edges (connectivity between nodes/cells).
- batch : torch.Tensor
Tensor representing the batch of data.
Returns
- msg : torch.Tensor
The message propagated between cells based on gene-to-gene interactions.
- msg_intra : torch.Tensor
The message based on intracellular gene regulation.
- log_z_mft : torch.Tensor
The Mean Field Theory approximation to the log partition function.
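As an illustration of the edge_index and batch arguments described above, here is a minimal sketch of a k-nearest-neighbour graph built from 2D cell positions. It assumes the PyTorch Geometric convention of a two-row edge list (sources, targets) and a batch vector that assigns every node to a graph. The helper name knn_edge_index is hypothetical, plain lists stand in for tensors, and the library's own graph construction may differ.

```python
import math

def knn_edge_index(positions, n_neighbors):
    """Hypothetical helper: return [sources, targets] linking each point
    to its n_neighbors nearest points (brute force, Euclidean distance)."""
    sources, targets = [], []
    for i, p in enumerate(positions):
        # Distances from point i to every other point, nearest first.
        dists = sorted(
            (math.dist(p, q), j) for j, q in enumerate(positions) if j != i
        )
        for _, j in dists[:n_neighbors]:
            sources.append(j)  # the edge starts at neighbour j ...
            targets.append(i)  # ... and ends at cell i
    return [sources, targets]

positions = [(0.0, 0.0), (1.0, 0.0), (5.0, 0.0)]
edge_index = knn_edge_index(positions, n_neighbors=1)
batch = [0] * len(positions)  # assumption: all cells belong to a single graph
```

To pass these to a model, the lists would typically be wrapped with torch.tensor(..., dtype=torch.long), as in the class example.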
- log_Z_mft(edge_index, batch)
Computes the Mean Field Theory (MFT) approximation to the log partition function, which is used to estimate the likelihood of gene expression states in the dataset.
Parameters
- edge_index : torch.Tensor
Tensor representing the graph edges (connectivity between nodes/cells).
- batch : torch.Tensor
Tensor representing the batch of data.
Returns
- log_z_mft : torch.Tensor
The log partition function estimated using Mean Field Theory (MFT).
- set_g2g(g2g)
Artificially sets the gene-to-gene (G2G) interaction matrix.
Parameters
- g2g : torch.Tensor
A matrix representing gene-to-gene interactions to be used for graph convolution.
- set_g2g_intra(g2g_intra)
Artificially sets the intracellular gene regulation matrix.
Parameters
- g2g_intra : torch.Tensor
A matrix representing intracellular gene regulation interactions.
- set_gex(gex)
Sets the gene expression matrix to be used during the forward pass.
Parameters
- gex : torch.Tensor
A matrix representing the gene expression of the cells.
- z_interaction(num_spots, g)
Avoids exploding exponentials in the partition function approximation by returning an approximate interaction term.
Parameters
- num_spots : int
Number of spots (cells) in the dataset.
- g : torch.Tensor
Norm of the sum of mean gene expressions weighted by gene-to-gene interactions.
Returns
- z_interaction : torch.Tensor
The approximated interaction term for the partition function.
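The formula used by z_interaction is not given here. As a generic illustration of how exploding exponentials are commonly avoided when working with log partition functions, the sketch below shows the log-sum-exp trick. This is a standard technique and an assumption for illustration only, not a description of the library's implementation; the function name log_sum_exp is hypothetical.

```python
import math

def log_sum_exp(values):
    """Compute log(sum(exp(v) for v in values)) without overflow."""
    m = max(values)
    # Factor out exp(m): every remaining exponent is <= 0, so exp() stays in (0, 1].
    return m + math.log(sum(math.exp(v - m) for v in values))

# The naive form overflows for large arguments.
naive_overflows = False
try:
    math.log(math.exp(1000.0) + math.exp(999.0))
except OverflowError:
    naive_overflows = True

# The stable form returns a finite value for the same inputs.
stable = log_sum_exp([1000.0, 999.0])
```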
The simcomen class
- class celcomen.models.simcomen.simcomen(input_dim, output_dim, n_neighbors, seed=0)
Bases: Module
A k-hop Graph Neural Network model for predicting spatial counterfactuals, such as localised perturbations.
Parameters
- input_dim : int
The dimensionality of the input gene expression features.
- output_dim : int
The dimensionality of the output features after processing through graph convolution and linear layers.
- n_neighbors : int
The number of neighbors to use in constructing the k-nearest neighbor graph.
- seed : int, optional
Random seed for reproducibility. Default is 0.
Attributes
- conv1 : GCNConv
Graph convolutional layer for gene-to-gene (G2G) interactions.
- lin : torch.nn.Linear
Linear layer for intracellular gene regulation.
- n_neighbors : int
Number of spatial neighbors used for constructing the graph.
- sphex : torch.nn.Parameter or None
Spherical gene expression matrix, set via set_sphex.
- gex : torch.nn.Parameter or None
Gene expression matrix, calculated from the spherical expression matrix.
- output_dim : int
Output dimensionality of the model.
Methods
- set_g2g(g2g)
Sets the gene-to-gene (G2G) interaction matrix artificially.
- set_g2g_intra(g2g_intra)
Sets the intracellular gene regulation matrix artificially.
- set_sphex(sphex)
Sets the spherical gene expression matrix artificially.
- forward(edge_index, batch)
Forward pass of the model, calculating messages from gene-to-gene interactions, intracellular interactions, and the log partition function (log(Z_mft)).
- log_Z_mft(edge_index, batch)
Computes the Mean Field Theory (MFT) approximation to the partition function for the current gene expressions.
- z_interaction(num_spots, g)
Calculates the interaction term for the partition function while avoiding numerical instability.
- calc_gex(sphex)
Converts the spherical gene expression matrix into a regular gene expression matrix.
- calc_sphex(gex)
Converts the regular gene expression matrix into a spherical gene expression matrix.
- get_pos(n_x, n_y)
Generates a 2D hexagonal grid of positions for spatial modeling.
- normalize_g2g(g2g)
Symmetrizes and normalizes the gene-to-gene interaction matrix.
Examples
>>> import torch
>>> from celcomen.models.simcomen import simcomen
>>> model = simcomen(input_dim=1000, output_dim=100, n_neighbors=6, seed=42)
>>> edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
>>> batch = torch.tensor([0, 1], dtype=torch.long)
>>> model.set_sphex(torch.randn(100, 1000))
>>> msg, msg_intra, log_z_mft = model(edge_index, batch)
>>> print(log_z_mft)
- calc_gex(sphex)
Converts the spherical expression matrix into a regular gene expression matrix.
Parameters
- sphex : torch.Tensor
The spherical gene expression matrix.
Returns
- gex : torch.Tensor
The converted regular gene expression matrix.
- calc_sphex(gex)
Converts the regular gene expression matrix into a spherical gene expression matrix.
Parameters
- gex : torch.Tensor
The regular gene expression matrix.
Returns
- sphex : torch.Tensor
The converted spherical gene expression matrix.
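The exact parameterisation behind calc_gex and calc_sphex is not spelled out above. The sketch below shows one common reading, assumed here for illustration: each cell's expression vector is a unit vector described by standard hyperspherical angles, so n genes are encoded by n-1 angles. The function names with the _sketch suffix are hypothetical, the code works on a single cell rather than a matrix, and the library's convention may differ.

```python
import math

def calc_gex_sketch(angles):
    """Map n-1 angles to an n-dimensional unit vector (hyperspherical coordinates)."""
    gex, sin_prod = [], 1.0
    for a in angles:
        gex.append(sin_prod * math.cos(a))
        sin_prod *= math.sin(a)
    gex.append(sin_prod)  # the last coordinate is the product of all sines
    return gex

def calc_sphex_sketch(gex):
    """Inverse map: recover the angles from a unit vector."""
    angles = []
    for i in range(len(gex) - 2):
        # The norm of the tail gex[i+1:] equals (product of earlier sines) * sin(angle_i).
        tail = math.sqrt(sum(x * x for x in gex[i + 1:]))
        angles.append(math.atan2(tail, gex[i]))
    # The final angle keeps the sign of the last coordinate.
    angles.append(math.atan2(gex[-1], gex[-2]))
    return angles

angles = [0.3, 1.2, 2.0]          # three angles encode four "genes"
gex = calc_gex_sketch(angles)
recovered = calc_sphex_sketch(gex)
```

Under this assumption the vector produced by calc_gex_sketch always has unit norm, and the round trip recovers the angles when all but the last lie in [0, π].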
- forward(edge_index, batch)
Forward pass of the model, calculates the messages between nodes using gene-to-gene interactions and intracellular gene regulation. Also calculates the log partition function using Mean Field Theory.
Parameters
- edge_index : torch.Tensor
Tensor representing the graph edges (connections between nodes/cells).
- batch : torch.Tensor
Tensor representing the batch of data.
Returns
- msg : torch.Tensor
Message passed between nodes based on gene-to-gene interactions.
- msg_intra : torch.Tensor
Message passed within nodes based on intracellular gene regulation.
- log_z_mft : torch.Tensor
Mean Field Theory approximation of the log partition function.
- get_pos(n_x, n_y)
Generates a 2D hexagonal grid of positions for spatial modelling.
Parameters
- n_x : int
Number of positions along the x-axis.
- n_y : int
Number of positions along the y-axis.
Returns
- pos : numpy.ndarray
Array of 2D positions for the grid.
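For readers unfamiliar with hexagonal layouts, the sketch below shows one way such a grid can be generated. It assumes unit spacing, odd rows shifted by half a step, and rows sqrt(3)/2 apart, which makes every point equidistant from its six nearest neighbours. The name hex_grid_sketch is hypothetical, a list of tuples stands in for the numpy array, and the library's ordering and scaling may differ.

```python
import math

def hex_grid_sketch(n_x, n_y):
    """Return n_x * n_y (x, y) positions on a hexagonal lattice with unit spacing."""
    positions = []
    for row in range(n_y):
        # Odd rows are shifted by half a step so each point sits between
        # two points of the neighbouring rows.
        x_offset = 0.5 if row % 2 else 0.0
        y = row * math.sqrt(3) / 2
        for col in range(n_x):
            positions.append((col + x_offset, y))
    return positions

pos = hex_grid_sketch(n_x=3, n_y=2)
```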
- log_Z_mft(edge_index, batch)
Computes the Mean Field Theory (MFT) approximation of the partition function. This function assumes that gene expression values are close to their mean across the spatial slide.
Parameters
- edge_index : torch.Tensor
Tensor representing the graph edges (connections between nodes/cells).
- batch : torch.Tensor
Tensor representing the batch of data.
Returns
- log_z_mft : torch.Tensor
Mean Field Theory approximation of the log partition function.
- normalize_g2g(g2g)
Symmetrizes and normalizes the gene-to-gene interaction matrix.
Parameters
- g2g : numpy.ndarray
The gene-to-gene interaction matrix.
Returns
- g2g : numpy.ndarray
The normalized and symmetrized gene-to-gene interaction matrix.
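The normalization used by normalize_g2g is not specified above. As a minimal sketch of the two steps named in the description, the example below averages the matrix with its transpose and then scales it to unit Frobenius norm. The choice of norm is an assumption, the name normalize_g2g_sketch is hypothetical, and nested lists stand in for the numpy array.

```python
import math

def normalize_g2g_sketch(g2g):
    """Symmetrize a square matrix, then scale it to unit Frobenius norm."""
    n = len(g2g)
    # Average each entry with its transposed counterpart.
    sym = [[(g2g[i][j] + g2g[j][i]) / 2 for j in range(n)] for i in range(n)]
    norm = math.sqrt(sum(v * v for row in sym for v in row))
    if norm == 0:
        return sym  # all-zero matrix: nothing to scale
    return [[v / norm for v in row] for row in sym]

g2g = [[1.0, 2.0], [0.0, 1.0]]
out = normalize_g2g_sketch(g2g)
```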
- set_g2g(g2g)
Artificially sets the gene-to-gene (G2G) interaction matrix.
Parameters
- g2g : torch.Tensor
A tensor representing gene-to-gene interactions to be used in the graph convolution.
- set_g2g_intra(g2g_intra)
Artificially sets the intracellular regulation matrix.
Parameters
- g2g_intra : torch.Tensor
A tensor representing intracellular gene regulation interactions.
- set_sphex(sphex)
Sets the spherical gene expression matrix for the forward pass.
Parameters
- sphex : torch.Tensor
A tensor representing the spherical expression matrix.
- z_interaction(num_spots, g)
Calculates the interaction term for the partition function approximation, avoiding exploding exponentials.
Parameters
- num_spots : int
Number of spots (cells) in the dataset.
- g : torch.Tensor
Norm of the sum of mean gene expressions weighted by gene-to-gene interactions.
Returns
- z_interaction : torch.Tensor
Approximated interaction term for the partition function.
Utility functions
- Converts a spherical gene expression matrix into a standard gene expression matrix.
- Prepares and returns a PyTorch Geometric DataLoader from a single-cell spatial transcriptomics dataset.