import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.colors import Colormap
import pandas as pd
from typing import Literal, Any
[docs]
class GeneSetActivityScoresReport:
"""Report object for storing GESSO geneset activity score results."""
def __init__(
self,
gas_df: pd.DataFrame,
locations_df: pd.DataFrame,
geneset_to_gene_contributions_df_dict: dict,
) -> None:
"""Initializes the GeneSetActivityScoresReport object.
Parameters
----------
gas_df : pd.DataFrame
geneset activity scores DataFrame. Should be of size (n_obs, n_genesets).
location_df : pd.DataFrame
Locations DataFrame. Should be of size (n_obs, 2).
geneset_to_gene_contributions_df_dict: dict
Dictionary of geneset to gene contribution DataFrames.
"""
self._gas_df = gas_df
self._location_df = locations_df
self._orig_spot_order = locations_df.index
self._geneset_to_gene_contributions_df_dict: dict[str, pd.DataFrame] = (
geneset_to_gene_contributions_df_dict
)
self._n_examples, self._n_genesets = gas_df.shape
[docs]
def gene_contributions_df(
self,
geneset: str,
sort_by: Literal["gene_contribution", "gene_name"] = "gene_contribution",
) -> pd.DataFrame:
"""Returns a gene contribution DataFrame with a single column (geneset name).
The index is the gene name.
Parameters
----------
geneset : str
geneset name.
sort_by : Literal["gene_contribution", "gene_name"]
Default: "gene_contribution". How to sort the DataFrame.
If "gene_contribution", sorts by the gene contribution weight (descending).
If "gene_name", sorts by the gene name (ascending).
Returns
-------
pd.DataFrame
"""
output = self._geneset_to_gene_contributions_df_dict[geneset]
if sort_by == "gene_contribution":
output = output.sort_values(by=geneset, ascending=False)
elif sort_by == "gene_name":
# the gene name is in the index
output = output.sort_index(ascending=True)
else:
raise ValueError(
f"Invalid sort_by value: {sort_by}. "
"Must be 'gene_contribution' or 'gene_name'."
)
return output
[docs]
def locations_df(self) -> pd.DataFrame:
"""Returns the locations DataFrame.
The index is the spot ID. The columns are "x" and "y".
Returns
-------
pd.DataFrame
"""
return self._location_df[["x", "y"]]
[docs]
def gas_df(self) -> pd.DataFrame:
"""Returns the geneset activity scores as a DataFrame.
The index is the spot ID. The columns are the geneset names.
Returns
-------
pd.DataFrame
"""
return self._gas_df.loc[self._orig_spot_order]
[docs]
def plot_gas_spatial_map(
self,
geneset: str,
size: int = 20,
cmap: Colormap | str = "viridis",
show_coords: bool = False,
figsize: tuple[float, float] = (5.0, 5.0),
ax: Axes | None = None,
) -> Figure:
"""Plots the geneset activity scores of a given geneset of interest
across all locations.
Parameters
----------
geneset : str
The name of the geneset to plot.
size : int
Default: 20. The size of the scatter points.
cmap : Colormap | None
Default: "viridis". The colormap to use for the scatter plot.
show_coords : bool
Default: False. If True, shows the coordinates of the points.
figsize : tuple[float, float]
Default: (5.0, 5.0) The size of the figure.
ax : plt.Axes | None
Default: None. If None, creates a new figure.
"""
if cmap is None:
cmap = "viridis"
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
fig = ax.get_figure()
plotting_df = self._location_df.join(self._gas_df[geneset])
cdata = plotting_df[geneset].to_numpy()
scatter = ax.scatter(
x=plotting_df["x"].to_numpy(),
y=plotting_df["y"].to_numpy(),
c=cdata,
s=size,
cmap=cmap,
vmin=cdata.min(),
vmax=cdata.max(),
)
if not show_coords:
ax.set_xticks([])
ax.set_yticks([])
ax.grid(False)
for spine in ax.spines.values():
spine.set_visible(False)
fig.colorbar(scatter, ax=ax, fraction=0.02, pad=0.01)
ax.set_title(f"GESSO Gene Set Activity Scores")
fig.tight_layout()
plt.close(fig)
return fig
[docs]
class PermutationTestReport:
"""Report object for storing GESSO permutation test results."""
def __init__(
self,
geneset: str,
permutation_test_df: pd.DataFrame,
):
"""Initializes the PermutationTestReport object.
Parameters
----------
geneset : str
The name of the geneset for which the permutation test was performed.
permuation_test_df : pd.DataFrame
DataFrame containing the results of the permutation test.
Should have columns: 'x', 'y', 'gas', 'p'
"""
self._geneset = geneset
self._permutation_test_df = permutation_test_df
[docs]
def plot_gas_spatial_map(
self,
size: int = 20,
cmap: Colormap | str = "viridis",
show_coords: bool = False,
figsize: tuple[float, float] = (5.0, 5.0),
ax: Axes | None = None,
) -> Figure:
"""Plots the gene set activity scores of the permutation
test across all locations.
Parameters
----------
size : int
Default: 20. The size of the scatter points.
cmap : Colormap | None
Default: "viridis". The colormap to use for the scatter plot.
show_coords : bool
Default: False. If True, shows the coordinates of the points.
figsize : tuple[float, float]
Default: (5.0, 5.0). The size of the figure.
ax : plt.Axes | None
Default: None. If None, creates a new figure.
"""
if cmap is None:
cmap = "viridis"
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
fig = ax.get_figure()
plotting_df = self._permutation_test_df
cdata = plotting_df["gas"].to_numpy()
scatter = ax.scatter(
x=plotting_df["x"],
y=plotting_df["y"],
c=cdata,
s=size,
cmap=cmap,
vmin=cdata.min(),
vmax=cdata.max(),
)
if not show_coords:
ax.set_xticks([])
ax.set_yticks([])
ax.grid(False)
for spine in ax.spines.values():
spine.set_visible(False)
fig.colorbar(scatter, ax=ax, fraction=0.02, pad=0.01)
ax.set_title(f"GESSO: Permutation Test Results for {self._geneset}")
fig.tight_layout()
plt.close(fig)
return fig
[docs]
def plot_pval_spatial_map(
self,
size: int = 20,
significance_threshold: float = 0.05,
significant_color: str | Any = "#800080", # dark purple hex code
not_significant_color: str | Any = "#D3D3D3",
show_coords: bool = False,
figsize: tuple[float, float] = (5.0, 5.0),
ax: Axes | None = None,
) -> Figure:
"""Plots the p-values of the permutation test across all locations.
Parameters
----------
size : int
Default: 20. The size of the scatter points.
significance_threshold : float
Default: 0.05. The threshold for significance.
significant_color : str
Default: "#800080". The color for significant points.
not_significant_color : str
Default: "#D3D3D3". The color for not significant points.
show_coords : bool
Default: False. If True, shows the coordinates of the points.
figsize : tuple[float, float]
Default: (5.0, 5.0). The size of the figure.
ax : plt.Axes | None
Default: None. If None, creates a new figure.
"""
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
fig = ax.get_figure()
plotting_df = self._permutation_test_df
colors = [
significant_color if p < significance_threshold else not_significant_color
for p in plotting_df["p"]
]
ax.scatter(
x=plotting_df["x"],
y=plotting_df["y"],
c=colors,
s=size,
)
ax.set_title(f"GESSO: Spots with Elevated Activity")
if not show_coords:
ax.set_xticks([])
ax.set_yticks([])
ax.grid(False)
for spine in ax.spines.values():
spine.set_visible(False)
fig.tight_layout()
plt.close(fig)
return fig
[docs]
def gas_df(self) -> pd.DataFrame:
"""Returns the geneset activity scores DataFrame.
Returns
-------
pd.DataFrame
DataFrame containing the geneset activity scores.
"""
return self._permutation_test_df[["gas"]].rename(columns={"gas": self._geneset})
[docs]
def pval_df(self) -> pd.DataFrame:
"""Returns the p-values DataFrame.
Returns
-------
pd.DataFrame
DataFrame containing the p-values.
"""
return self._permutation_test_df[["p"]]
[docs]
def locations_df(self) -> pd.DataFrame:
"""Returns the locations DataFrame.
Returns
-------
pd.DataFrame
DataFrame containing the locations of the spots.
Contains two columns: 'x' and 'y'.
"""
return self._permutation_test_df[["x", "y"]]
[docs]
def htest_df(self) -> pd.DataFrame:
"""Returns the full permutation test DataFrame.
Returns
-------
pd.DataFrame
DataFrame containing the full permutation test results.
Contains four columns: 'x', 'y', 'gas', 'p'.
"""
return self._permutation_test_df