Celcomen

Project home page here.

Main workflow classes

The celcomen class

class celcomen.models.celcomen.celcomen(input_dim, output_dim, n_neighbors, seed=0)

Bases: Module

A k-hop Graph Neural Network model for disentangling inter- and intracellular gene regulation and then leveraging the learned regulation to predict spatial counterfactuals.

Parameters

input_dim : int

Dimensionality of the input features (gene expression data).

output_dim : int

Dimensionality of the output features.

n_neighbors : int

The number of neighbours used in the spatial graph to model cell-cell interactions.

seed : int, optional

Seed for random number generation to ensure reproducibility. Default is 0.

Attributes

conv1 : GCNConv

A graph convolutional layer that models gene-to-gene interactions (G2G).

lin : torch.nn.Linear

A linear layer that models intracellular gene regulation.

n_neighbors : int

The number of neighbours for spatial graph construction.

gex : torch.nn.Parameter or None

Stores the gene expression matrix used for the forward pass. Set to None initially.

Methods

set_g2g(g2g)

Sets the gene-to-gene (G2G) interaction matrix artificially.

set_g2g_intra(g2g_intra)

Sets the intracellular regulation matrix artificially.

set_gex(gex)

Sets the gene expression matrix artificially.

forward(edge_index, batch)

Forward pass to compute the gene-to-gene and intracellular messages, and the log partition function estimate.

log_Z_mft(edge_index, batch)

Computes the Mean Field Theory (MFT) approximation to the partition function.

z_interaction(num_spots, g)

Provides an approximation for the interaction term in the partition function to prevent numerical instability due to exploding exponentials.

Examples

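>>> import torch
>>> from celcomen.models.celcomen import celcomen
>>> # G2G is a gene-by-gene matrix, so output_dim matches input_dim (number of genes)
>>> model = celcomen(input_dim=1000, output_dim=1000, n_neighbors=6, seed=42)
>>> edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)  # two cells, linked both ways
>>> batch = torch.tensor([0, 1], dtype=torch.long)
>>> model.set_gex(torch.randn(2, 1000))  # one row per cell, one column per gene
>>> msg, msg_intra, log_z_mft = model(edge_index, batch)
>>> print(log_z_mft)

forward(edge_index, batch)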

Forward pass for the model, computing gene-to-gene and intracellular messages, and estimating the log partition function using Mean Field Theory (MFT).

Parameters

edge_index : torch.Tensor

Edge index of shape [2, num_edges] in PyTorch Geometric's COO format, giving the connectivity between nodes (cells).

batch : torch.Tensor

Batch vector assigning each node (cell) to its graph in the mini-batch.

Returns

msg : torch.Tensor

The message propagated between cells based on gene-to-gene interactions.

msg_intra : torch.Tensor

The message based on intracellular gene regulation.

log_z_mft : torch.Tensor

The Mean Field Theory approximation to the log partition function.
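To make the two messages concrete, here is a minimal NumPy sketch, not the package's implementation. It assumes `msg` follows the standard symmetric-normalised GCN propagation rule used by PyTorch Geometric's `GCNConv` (here without self-loops or bias, which is an assumption) and that `msg_intra` is a bias-free linear map of each cell's own expression. The function name `messages_sketch` and its plain-matrix arguments `g2g` and `g2g_intra` are hypothetical.

```python
import numpy as np


def messages_sketch(gex, edge_index, g2g, g2g_intra):
    """Hypothetical NumPy version of the G2G and intracellular messages.

    gex:        (n_cells, n_genes) expression matrix
    edge_index: (2, n_edges) integer array, row 0 = source, row 1 = target
    g2g:        (n_genes, n_genes) gene-to-gene weight matrix
    g2g_intra:  (n_genes, n_genes) intracellular weight matrix
    """
    gex = np.asarray(gex, dtype=float)
    n_cells = gex.shape[0]
    src, dst = np.asarray(edge_index)
    # In-degree of each target cell (this sketch adds no self-loops).
    deg = np.bincount(dst, minlength=n_cells).astype(float)
    inv_sqrt = np.zeros_like(deg)
    inv_sqrt[deg > 0] = deg[deg > 0] ** -0.5
    # Symmetric GCN normalisation 1 / sqrt(deg_src * deg_dst) for every edge.
    norm = inv_sqrt[src] * inv_sqrt[dst]
    # Transform every cell's expression with the G2G matrix (x @ W^T, no bias).
    transformed = gex @ np.asarray(g2g, dtype=float).T
    # Sum the normalised, transformed neighbour expressions into each target cell.
    msg = np.zeros((n_cells, transformed.shape[1]))
    np.add.at(msg, dst, norm[:, None] * transformed[src])
    # Intracellular message: each cell's own expression through the intra matrix.
    msg_intra = gex @ np.asarray(g2g_intra, dtype=float).T
    return msg, msg_intra


gex = np.array([[1.0, 0.0], [0.0, 1.0]])
edge_index = np.array([[0, 1], [1, 0]])
msg, msg_intra = messages_sketch(gex, edge_index, np.eye(2), 2 * np.eye(2))
print(msg.tolist())        # [[0.0, 1.0], [1.0, 0.0]]: each cell receives its neighbour's expression
print(msg_intra.tolist())  # [[2.0, 0.0], [0.0, 2.0]]
```

With an identity G2G matrix each cell simply receives its neighbour's expression, which makes it easy to see the inter- versus intracellular split.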

log_Z_mft(edge_index, batch)

Computes the Mean Field Theory (MFT) approximation to the partition function, the normalizing constant of the probability of gene expression states in the dataset.

Parameters

edge_index : torch.Tensor

Edge index of shape [2, num_edges] in PyTorch Geometric's COO format, giving the connectivity between nodes (cells).

batch : torch.Tensor

Batch vector assigning each node (cell) to its graph in the mini-batch.

Returns

log_z_mft : torch.Tensor

The log partition function estimated using Mean Field Theory (MFT).

set_g2g(g2g)

Artificially sets the gene-to-gene (G2G) interaction matrix.

Parameters

g2g : torch.Tensor

A matrix representing gene-to-gene interactions to be used for graph convolution.

set_g2g_intra(g2g_intra)

Artificially sets the intracellular gene regulation matrix.

Parameters

g2g_intra : torch.Tensor

A matrix representing intracellular gene regulation interactions.

set_gex(gex)

Sets the gene expression matrix to be used during the forward pass.

Parameters

gex : torch.Tensor

A matrix representing the gene expression of the cells.

z_interaction(num_spots, g)

Avoids exploding exponentials in the partition function approximation by returning an approximate interaction term.

Parameters

num_spots : int

Number of spots (cells) in the dataset.

g : torch.Tensor

Norm of the sum of mean gene expressions weighted by gene-to-gene interactions.

Returns

z_interaction : torch.Tensor

The approximated interaction term for the partition function.
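One standard way to keep a term like this finite is to evaluate it in log space. The sketch below assumes, for illustration only, that the interaction term has the form num_spots * log(2 * sinh(g) / g); that form is not stated on this page. Rewriting 2 * sinh(g) = e^g * (1 - e^(-2g)) gives log(2 * sinh(g) / g) = g + log(1 - e^(-2g)) - log(g), which never exponentiates a large positive number and approaches g - log(g) for large g. The function name `z_interaction_sketch` is hypothetical, and it uses plain Python floats rather than tensors (`torch.expm1` and `torch.log` would play the same roles).

```python
import math


def z_interaction_sketch(num_spots: int, g: float) -> float:
    """num_spots * log(2 * sinh(g) / g), computed without overflow (requires g > 0)."""
    if g <= 0:
        raise ValueError("g must be positive")
    # 2*sinh(g) = exp(g) * (1 - exp(-2g)), so the log splits into three stable terms.
    # -expm1(-2g) equals 1 - exp(-2g) but stays accurate when g is tiny.
    return num_spots * (g + math.log(-math.expm1(-2.0 * g)) - math.log(g))


# The naive math.log(2 * math.sinh(g) / g) raises OverflowError for g = 1000.
print(z_interaction_sketch(100, 1000.0))  # equals 100 * (1000 - log(1000))
```

The same identity also behaves well as g approaches 0, where the term tends to num_spots * log(2).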

The simcomen class

class celcomen.models.simcomen.simcomen(input_dim, output_dim, n_neighbors, seed=0)

Bases: Module

A k-hop Graph Neural Network model for predicting spatial counterfactuals, such as localised perturbations.

Parameters

input_dim : int

The dimensionality of the input gene expression features.

output_dim : int

The dimensionality of the output features after processing through graph convolution and linear layers.

n_neighbors : int

The number of neighbors to use in constructing the k-nearest neighbor graph.

seed : int, optional

Random seed for reproducibility, default is 0.
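A k-nearest-neighbour graph can be expressed in the [2, num_edges] edge-index layout that PyTorch Geometric layers such as `GCNConv` consume. The brute-force NumPy sketch below is not the package's own graph builder; the function name `knn_edge_index` is hypothetical, and it assumes each cell is the target of edges from its `n_neighbors` nearest other cells by Euclidean distance.

```python
import numpy as np


def knn_edge_index(pos: np.ndarray, n_neighbors: int) -> np.ndarray:
    """Hypothetical brute-force k-NN graph in PyTorch Geometric edge-index layout.

    pos: (n_cells, n_dims) spatial coordinates.
    Returns a (2, n_cells * n_neighbors) int64 array; row 0 holds the source
    (neighbour) and row 1 the target cell that receives its message.
    """
    pos = np.asarray(pos, dtype=float)
    n_cells = pos.shape[0]
    # Pairwise Euclidean distances, with each cell excluded from its own neighbours.
    diff = pos[:, None, :] - pos[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)
    # Indices of the n_neighbors closest cells for every cell (stable for ties).
    nbrs = np.argsort(dist, axis=1, kind="stable")[:, :n_neighbors]
    targets = np.repeat(np.arange(n_cells), n_neighbors)
    sources = nbrs.reshape(-1)
    return np.stack([sources, targets]).astype(np.int64)


# Four cells on a line at x = 0, 1, 3, 7, each linked to its single nearest cell.
edge_index = knn_edge_index(np.array([[0.0], [1.0], [3.0], [7.0]]), n_neighbors=1)
print(edge_index.tolist())  # [[1, 0, 1, 2], [0, 1, 2, 3]]
```

The brute-force distance matrix is quadratic in the number of cells; for large slides a spatial index (for example a k-d tree) would be the usual replacement.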

Attributes

conv1 : GCNConv

Graph convolutional layer for gene-to-gene (G2G) interactions.

lin : torch.nn.Linear

Linear layer for intracellular gene regulation.

n_neighbors : int

Number of spatial neighbors used for constructing the graph.

sphex : torch.nn.Parameter or None

Spherical gene expression matrix, set via set_sphex.

gex : torch.nn.Parameter or None

Gene expression matrix, calculated from the spherical expression matrix.

output_dim : int

Output dimensionality of the model.

Methods

set_g2g(g2g)

Sets the gene-to-gene (G2G) interaction matrix artificially.

set_g2g_intra(g2g_intra)

Sets the intracellular gene regulation matrix artificially.

set_sphex(sphex)

Sets the spherical gene expression matrix artificially.

forward(edge_index, batch)

Forward pass of the model, calculating messages from gene-to-gene interactions, intracellular interactions, and the log partition function (log(Z_mft)).

log_Z_mft(edge_index, batch)

Computes the Mean Field Theory (MFT) approximation to the partition function for the current gene expressions.

z_interaction(num_spots, g)

Calculates the interaction term for the partition function while avoiding numerical instability.

calc_gex(sphex)

Converts the spherical gene expression matrix into a regular gene expression matrix.

calc_sphex(gex)

Converts the regular gene expression matrix into a spherical gene expression matrix.

get_pos(n_x, n_y)

Generates a 2D hexagonal grid of positions for spatial modeling.

normalize_g2g(g2g)

Symmetrizes and normalizes the gene-to-gene interaction matrix.

Examples

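>>> import torch
>>> from celcomen.models.simcomen import simcomen
>>> # G2G is a gene-by-gene matrix, so output_dim matches input_dim (number of genes)
>>> model = simcomen(input_dim=1000, output_dim=1000, n_neighbors=6, seed=42)
>>> edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)  # two cells, linked both ways
>>> batch = torch.tensor([0, 1], dtype=torch.long)
>>> model.set_sphex(torch.randn(2, 1000))  # one row per cell
>>> msg, msg_intra, log_z_mft = model(edge_index, batch)
>>> print(log_z_mft)

calc_gex(sphex)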

Converts the spherical expression matrix into a regular gene expression matrix.

Parameters

sphex : torch.Tensor

The spherical gene expression matrix.

Returns

gex : torch.Tensor

The converted regular gene expression matrix.
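Assuming the spherical matrix stores standard hyperspherical angles (n_genes - 1 angles per cell, so that each converted row is a unit-norm expression vector), the conversion can be sketched as below. The angle convention, the column count, and the name `sphex_to_gex` are assumptions not confirmed by this page, and the sketch uses NumPy arrays rather than tensors.

```python
import numpy as np


def sphex_to_gex(sphex):
    """Hypothetical hyperspherical-to-Cartesian conversion.

    sphex: (n_cells, n_genes - 1) angles -> returns (n_cells, n_genes) unit vectors.
    gex[:, k] = sin(phi_0) * ... * sin(phi_{k-1}) * cos(phi_k); the last column
    is the product of all sines.
    """
    sphex = np.atleast_2d(np.asarray(sphex, dtype=float))
    n_cells, n_angles = sphex.shape
    gex = np.empty((n_cells, n_angles + 1))
    sin_prod = np.ones(n_cells)  # running product of sines of earlier angles
    for k in range(n_angles):
        gex[:, k] = sin_prod * np.cos(sphex[:, k])
        sin_prod = sin_prod * np.sin(sphex[:, k])
    gex[:, n_angles] = sin_prod
    return gex


gex = sphex_to_gex([[0.0, 0.0], [np.pi / 2, np.pi / 2]])
print(np.round(gex, 6))             # rows close to [1, 0, 0] and [0, 0, 1]
print(np.linalg.norm(gex, axis=1))  # every row has norm 1
```

Because the squared coordinates telescope to 1, any set of angles yields a point on the unit sphere, which is what makes the angular parametrisation convenient for optimisation.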

calc_sphex(gex)

Converts the regular gene expression matrix into a spherical gene expression matrix.

Parameters

gex : torch.Tensor

The regular gene expression matrix.

Returns

sphex : torch.Tensor

The converted spherical gene expression matrix.
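Under the same assumed hyperspherical convention, the inverse direction recovers each angle with `arctan2`: angle k is the arctangent of the norm of the remaining coordinates over coordinate k, and the last angle uses the final two coordinates so that it covers the full circle. Rows are assumed to be nonzero, only their direction is kept, and the name `gex_to_sphex` is hypothetical.

```python
import numpy as np


def gex_to_sphex(gex):
    """Hypothetical Cartesian-to-hyperspherical conversion.

    gex: (n_cells, n_genes) with n_genes >= 2 -> (n_cells, n_genes - 1) angles.
    Only the direction of each row is kept; its length is discarded.
    """
    gex = np.atleast_2d(np.asarray(gex, dtype=float))
    n_cells, n_genes = gex.shape
    # tail_norm[:, k] = sqrt(gex[:, k]**2 + ... + gex[:, -1]**2)
    tail_norm = np.sqrt(np.cumsum(gex[:, ::-1] ** 2, axis=1)[:, ::-1])
    sphex = np.empty((n_cells, n_genes - 1))
    for k in range(n_genes - 2):
        # Angles in [0, pi]: how far the vector tilts away from axis k.
        sphex[:, k] = np.arctan2(tail_norm[:, k + 1], gex[:, k])
    # Final angle in (-pi, pi], fixed by the sign of the last coordinate.
    sphex[:, -1] = np.arctan2(gex[:, -1], gex[:, -2])
    return sphex


print(gex_to_sphex([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]))  # approximately [[0, 0], [pi/2, pi/2]]
```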

forward(edge_index, batch)

Forward pass of the model, calculates the messages between nodes using gene-to-gene interactions and intracellular gene regulation. Also calculates the log partition function using Mean Field Theory.

Parameters

edge_index : torch.Tensor

Edge index of shape [2, num_edges] in PyTorch Geometric's COO format, giving the connections between nodes (cells).

batch : torch.Tensor

Batch vector assigning each node (cell) to its graph in the mini-batch.

Returns

msg : torch.Tensor

Message passed between nodes based on gene-to-gene interactions.

msg_intra : torch.Tensor

Message passed within nodes based on intracellular gene regulation.

log_z_mft : torch.Tensor

Mean Field Theory approximation of the log partition function.

get_pos(n_x, n_y)

Generates a 2D hexagonal grid of positions for spatial modelling.

Parameters

n_x : int

Number of positions along the x-axis.

n_y : int

Number of positions along the y-axis.

Returns

pos : numpy.ndarray

Array of 2D positions for the grid.
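A hexagonal lattice can be built by shifting every other row by half a spacing and placing rows sqrt(3)/2 apart, which puts each interior point at unit distance from six neighbours. The sketch below shows that construction; the unit spacing, the row-major ordering, and the name `hex_grid_sketch` are assumptions rather than documented behaviour of `get_pos`.

```python
import numpy as np


def hex_grid_sketch(n_x: int, n_y: int) -> np.ndarray:
    """Hypothetical hexagonal grid: (n_x * n_y, 2) array of (x, y) positions."""
    ix, iy = np.meshgrid(np.arange(n_x), np.arange(n_y))  # both shaped (n_y, n_x)
    x = ix + 0.5 * (iy % 2)  # odd rows are shifted right by half a spacing
    y = iy * np.sqrt(3) / 2  # row spacing that makes neighbouring points 1 apart
    return np.column_stack([x.ravel(), y.ravel()])


pos = hex_grid_sketch(4, 3)
print(pos.shape)  # (12, 2)
print(np.linalg.norm(pos[1] - pos[0]), np.linalg.norm(pos[4] - pos[0]))  # both about 1.0
```

Equal neighbour distances mean a k-nearest-neighbour graph with k = 6 on such a grid recovers the hexagonal adjacency for interior points.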

log_Z_mft(edge_index, batch)

Computes the Mean Field Theory (MFT) approximation of the partition function. This function assumes that gene expression values are close to their mean across the spatial slide.

Parameters

edge_index : torch.Tensor

Edge index of shape [2, num_edges] in PyTorch Geometric's COO format, giving the connections between nodes (cells).

batch : torch.Tensor

Batch vector assigning each node (cell) to its graph in the mini-batch.

Returns

log_z_mft : torch.Tensor

Mean Field Theory approximation of the log partition function.

normalize_g2g(g2g)

Symmetrizes and normalizes the gene-to-gene interaction matrix.

Parameters

g2g : numpy.ndarray

The gene-to-gene interaction matrix.

Returns

g2g : numpy.ndarray

The normalized and symmetrized gene-to-gene interaction matrix.
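Symmetrising is usually done by averaging the matrix with its transpose. How the matrix is then normalised is not specified here; the sketch below assumes, purely for illustration, scaling by the largest absolute entry so that values lie in [-1, 1]. Other choices (for example dividing by the spectral norm) would be equally plausible, and the name `normalize_g2g_sketch` is hypothetical.

```python
import numpy as np


def normalize_g2g_sketch(g2g: np.ndarray) -> np.ndarray:
    """Hypothetical symmetrise-then-normalise step for a square G2G matrix."""
    g2g = np.asarray(g2g, dtype=float)
    sym = (g2g + g2g.T) / 2  # averaging with the transpose makes it symmetric
    scale = np.abs(sym).max()
    # Assumed normalisation: divide by the largest absolute entry (skip if all zero).
    return sym / scale if scale > 0 else sym


out = normalize_g2g_sketch(np.array([[0.0, 4.0], [0.0, 2.0]]))
print(out.tolist())  # [[0.0, 1.0], [1.0, 1.0]]
```

A symmetric G2G matrix treats an interaction from gene A to gene B the same as one from B to A, which keeps the implied energy function well defined.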

set_g2g(g2g)

Artificially sets the gene-to-gene (G2G) interaction matrix.

Parameters

g2g : torch.Tensor

A tensor representing gene-to-gene interactions to be used in the graph convolution.

set_g2g_intra(g2g_intra)

Artificially sets the intracellular regulation matrix.

Parameters

g2g_intra : torch.Tensor

A tensor representing intracellular gene regulation interactions.

set_sphex(sphex)

Sets the spherical gene expression matrix for the forward pass.

Parameters

sphex : torch.Tensor

A tensor representing the spherical expression matrix.

z_interaction(num_spots, g)

Calculates the interaction term for the partition function approximation, avoiding exploding exponentials.

Parameters

num_spots : int

Number of spots (cells) in the dataset.

g : torch.Tensor

Norm of the sum of mean gene expressions weighted by gene-to-gene interactions.

Returns

z_interaction : torch.Tensor

Approximated interaction term for the partition function.

Utility functions

celcomen.utils.helpers.calc_gex(sphex)

Converts a spherical gene expression matrix into a standard gene expression matrix.

celcomen.datareaders.datareader.get_dataset_loaders(...)

Prepares and returns PyTorch Geometric DataLoader objects from a single-cell spatial transcriptomics dataset.