Skip to content

Plotting

quicat.pl.barplot(adata, groupby, order=None, color='blue', figsize=(10, 6), xlabel=None, ylabel='Percentage (%)', title=None, show=True, save=None, dpi=150, **kwargs)

Creates a bar plot showing the percentage of each category in a specified variable.

Parameters:

Name Type Description Default
adata AnnData

The input AnnData object.

required
groupby str

Name of the variable (column in adata.obs) to plot.

required
order Optional[List[str]]

Specific order of the categories. If None, categories are ordered from highest to lowest percentage. Defaults to None.

None
color Union[str, Tuple[float, float, float]]

Color for the bars. Can be a color name, an RGB tuple, or a hex code. Defaults to 'blue'.

'blue'
figsize Tuple[float, float]

Size of the figure. Defaults to (10, 6).

(10, 6)
xlabel Optional[str]

Label for the x-axis. Defaults to None.

None
ylabel str

Label for the y-axis. Defaults to 'Percentage (%)'.

'Percentage (%)'
title Optional[str]

Title of the plot. Defaults to None.

None
show bool

If True, displays the plot. If False, returns the figure object. Defaults to True.

True
save Optional[str]

Path to save the figure. If None, the figure is not saved. Defaults to None.

None
dpi int

Resolution of the saved figure. Defaults to 150.

150
**kwargs

Additional keyword arguments to pass to sns.barplot.

{}

Returns:

Type Description
Optional[Figure]

Optional[plt.Figure]: The matplotlib Figure object if show is False, otherwise None.

Source code in src/quicat/plotting/_barplot.py
def barplot(
    adata: anndata.AnnData,
    groupby: str,
    order: Optional[List[str]] = None,
    color: Union[str, Tuple[float, float, float]] = "blue",
    figsize: Tuple[float, float] = (10, 6),
    xlabel: Optional[str] = None,
    ylabel: str = "Percentage (%)",
    title: Optional[str] = None,
    show: bool = True,
    save: Optional[str] = None,
    dpi: int = 150,
    **kwargs,
) -> Optional[plt.Figure]:
    """
    Creates a bar plot showing the percentage of each category in a specified variable.

    Percentages are always computed over all observations in `adata.obs[groupby]`,
    even when `order` selects only a subset of the categories.

    Parameters:
        adata (anndata.AnnData): The input AnnData object.
        groupby (str): Name of the variable (column in `adata.obs`) to plot.
        order (Optional[List[str]], optional): Specific order of the categories. If None,
            categories are ordered from highest to lowest percentage. Defaults to None.
        color (Union[str, Tuple[float, float, float]], optional): Color for the bars. Can be a color name, an RGB tuple, or a hex code.
            Defaults to 'blue'.
        figsize (Tuple[float, float], optional): Size of the figure. Defaults to (10, 6).
        xlabel (Optional[str], optional): Label for the x-axis. If None or empty, `groupby` is used.
            Defaults to None.
        ylabel (str, optional): Label for the y-axis. Defaults to 'Percentage (%)'.
        title (Optional[str], optional): Title of the plot. Defaults to None.
        show (bool, optional): If True, displays the plot. If False, returns the figure object. Defaults to True.
        save (Optional[str], optional): Path to save the figure. If None, the figure is not saved.
            Defaults to None.
        dpi (int, optional): Resolution of the saved figure. Defaults to 150.
        **kwargs: Additional keyword arguments to pass to `sns.barplot`.

    Returns:
        Optional[plt.Figure]: The matplotlib Figure object if `show` is False, otherwise None.

    Raises:
        KeyError: If `groupby` is not a column of `adata.obs`, or if `order` contains a
            category that does not occur in that column.
    """
    counts = adata.obs[groupby].value_counts().reset_index()
    # Give the two columns produced by reset_index() predictable names.
    counts.columns = [groupby, "count"]

    counts["percentage"] = 100 * counts["count"] / counts["count"].sum()

    if order is not None:
        counts = counts.set_index(groupby).loc[list(order)].reset_index()
    else:
        counts = counts.sort_values(by="percentage", ascending=False)

    # Pass the bar order explicitly: for a categorical column seaborn would
    # otherwise fall back to the dtype's category order and ignore the row
    # order established above.
    bar_order = counts[groupby].tolist()

    fig, ax = plt.subplots(figsize=figsize)
    sns.barplot(
        data=counts,
        x=groupby,
        y="percentage",
        order=bar_order,
        color=color,
        ax=ax,
        **kwargs,
    )

    ax.set_xlabel(xlabel if xlabel else groupby)
    ax.set_ylabel(ylabel)
    if title:
        ax.set_title(title)

    # Lay out this figure specifically rather than whatever pyplot considers current.
    fig.tight_layout()

    if save:
        fig.savefig(save, bbox_inches="tight", dpi=dpi)

    if show:
        plt.show()
        return None
    return fig

quicat.pl.stacked_barplot(adata, groupby, obs_key, palette=None, cmap=None, figsize=(10, 6), xlabel=None, ylabel='Percentage (%)', title=None, legend_title=None, show=True, save=None, dpi=150, **kwargs)

Creates a stacked bar plot showing the percentage distribution of one categorical variable over another.

Parameters:

Name Type Description Default
adata AnnData

The input AnnData object containing observations.

required
groupby str

The categorical variable in adata.obs used for grouping on the x-axis.

required
obs_key str

The categorical variable in adata.obs to stack.

required
palette Union[str, List[str]]

Seaborn palette or list of colors. Defaults to None.

None
cmap Optional[Dict[str, str]]

Dictionary mapping each category of obs_key to a specific color. Mutually exclusive with palette: passing both raises a ValueError. Defaults to None.

None
figsize Tuple[float, float]

Figure size. Defaults to (10, 6).

(10, 6)
xlabel Optional[str]

Label for the x-axis. Defaults to groupby.

None
ylabel str

Label for the y-axis. Defaults to 'Percentage (%)'.

'Percentage (%)'
title Optional[str]

Title of the plot. Defaults to None.

None
legend_title Optional[str]

Title for the legend. Defaults to obs_key.

None
show bool

If True, shows the plot. If False, returns the figure. Defaults to True.

True
save Optional[str]

Path to save the figure. Defaults to None.

None
dpi int

DPI for the saved figure. Defaults to 150.

150
**kwargs

Additional arguments passed to plt.bar.

{}

Returns:

Type Description
Optional[Figure]

Optional[plt.Figure]: The matplotlib Figure object if show is False, otherwise None.

Raises:

Type Description
ValueError

If groupby or obs_key is not found in adata.obs, if both palette and cmap are provided, or if cmap is missing a category of obs_key.

Source code in src/quicat/plotting/_stacked_barplot.py
def stacked_barplot(
    adata: anndata.AnnData,
    groupby: str,
    obs_key: str,
    palette: Optional[Union[str, List[str]]] = None,
    cmap: Optional[Dict[str, str]] = None,
    figsize: Tuple[float, float] = (10, 6),
    xlabel: Optional[str] = None,
    ylabel: str = "Percentage (%)",
    title: Optional[str] = None,
    legend_title: Optional[str] = None,
    show: bool = True,
    save: Optional[str] = None,
    dpi: int = 150,
    **kwargs,
) -> Optional[plt.Figure]:
    """
    Creates a stacked bar plot showing the percentage distribution of one categorical variable over another.

    Each bar corresponds to one value of `groupby`; its segments show the share
    (in percent) of every `obs_key` category within that group.

    Color resolution, in order of precedence:
        1. `cmap` (explicit category -> color mapping), or `palette` (seaborn palette
           name / list of colors). The two are mutually exclusive.
        2. A list stored in `adata.uns[f"{obs_key}_colors"]`, if it provides at least
           one color per plotted category.
        3. The seaborn "Set2" palette.

    Parameters:
        adata (anndata.AnnData): The input AnnData object containing observations.
        groupby (str): The categorical variable in `adata.obs` used for grouping on the x-axis.
        obs_key (str): The categorical variable in `adata.obs` to stack.
        palette (Union[str, List[str]], optional): Seaborn palette or list of colors. Defaults to None.
        cmap (Optional[Dict[str, str]], optional): Dictionary mapping categories in `obs_key` to specific colors.
            Cannot be combined with `palette`. Defaults to None.
        figsize (Tuple[float, float], optional): Figure size. Defaults to (10, 6).
        xlabel (Optional[str], optional): Label for the x-axis. Defaults to `groupby`.
        ylabel (str, optional): Label for the y-axis. Defaults to 'Percentage (%)'.
        title (Optional[str], optional): Title of the plot. Defaults to None.
        legend_title (Optional[str], optional): Title for the legend. Defaults to `obs_key`.
        show (bool, optional): If True, shows the plot. If False, returns the figure. Defaults to True.
        save (Optional[str], optional): Path to save the figure. Defaults to None.
        dpi (int, optional): DPI for the saved figure. Defaults to 150.
        **kwargs: Additional arguments passed to `ax.bar` for every stacked segment.

    Returns:
        Optional[plt.Figure]: The matplotlib Figure object if `show` is False, otherwise None.

    Raises:
        ValueError: If `groupby` or `obs_key` are not found in `adata.obs`, if both `palette` and `cmap`
            are provided, or if `cmap` lacks a color for one of the plotted categories.
    """
    if groupby not in adata.obs.columns:
        e = f"'{groupby}' not found in adata.obs columns."
        raise ValueError(e)
    if obs_key not in adata.obs.columns:
        e = f"'{obs_key}' not found in adata.obs columns."
        raise ValueError(e)
    if cmap is not None and palette is not None:
        e = "Specify either 'cmap' or 'palette', not both."
        raise ValueError(e)

    data = adata.obs[[groupby, obs_key]]
    pivot_table = data.pivot_table(index=groupby, columns=obs_key, aggfunc="size", fill_value=0)
    # Row-normalise the counts so that every bar sums to 100 %.
    relative_abundances = pivot_table.div(pivot_table.sum(axis=1), axis=0) * 100

    categories = list(relative_abundances.columns)
    n_categories = len(categories)

    # Resolve exactly one color per plotted category *before* creating the
    # figure, so that a validation error does not leave an empty figure behind.
    if cmap is not None:
        missing_categories = set(categories) - set(cmap.keys())
        if missing_categories:
            e = f"The following categories are missing in cmap: {missing_categories}"
            raise ValueError(e)
        colors = [cmap[category] for category in categories]
    elif palette is not None:
        colors = list(sns.color_palette(palette, n_colors=n_categories))
    else:
        uns_palette = adata.uns.get(f"{obs_key}_colors", None)
        # Only trust the stored palette when it can cover every category;
        # otherwise indexing it below would raise IndexError.
        if isinstance(uns_palette, list) and len(uns_palette) >= n_categories:
            colors = uns_palette[:n_categories]
        else:
            colors = list(sns.color_palette("Set2", n_colors=n_categories))

    fig, ax = plt.subplots(figsize=figsize)

    # Running height of each bar; float so that adding percentages never
    # requires an in-place dtype change.
    bottom = pd.Series(0.0, index=relative_abundances.index)
    for category, color in zip(categories, colors):
        heights = relative_abundances[category]
        ax.bar(relative_abundances.index, heights, bottom=bottom, color=color, label=category, **kwargs)
        bottom = bottom + heights

    ax.set_xlabel(xlabel if xlabel else groupby)
    ax.set_ylabel(ylabel)

    if title:
        ax.set_title(title)

    # Build handles and labels from the same (category, color) pairs used for
    # the bars, so legend entries can never drift out of sync with the plot
    # (e.g. when `cmap` has extra keys or a different key order).
    custom_handles = [
        mlines.Line2D([], [], color=color, marker="o", linestyle="None", markersize=10) for color in colors
    ]
    ax.legend(
        handles=custom_handles,
        labels=[str(category) for category in categories],
        title=legend_title if legend_title else obs_key,
        bbox_to_anchor=(1.05, 1),
        loc="upper left",
        alignment="left",
        frameon=False,
    )

    fig.tight_layout()
    if save:
        fig.savefig(save, bbox_inches="tight", dpi=dpi)
    if show:
        plt.show()
        return None
    return fig

quicat.pl.boxplot(adata, groupby, obs_key, hue=None, palette=None, cmap=None, figsize=(10, 6), xlabel=None, ylabel=None, title=None, legend_title=None, show=True, save=None, dpi=150, **kwargs)

Creates a box plot showing the distribution of a variable across categories.

Parameters:

Name Type Description Default
adata AnnData

The input AnnData object containing observations.

required
groupby str

The categorical variable in adata.obs used for grouping and color encoding on the x-axis.

required
obs_key str

The variable in adata.obs to plot on the y-axis.

required
hue str

The categorical variable in adata.obs used for color encoding. Defaults to None.

None
palette Union[str, List[str], Dict[str, str]]

Seaborn palette name, list of colors, or a dictionary mapping categories to colors. Defaults to None.

None
cmap Optional[Dict[str, str]]

Dictionary mapping categories in hue to specific colors. Mutually exclusive with palette: passing both raises a ValueError. Defaults to None.

None
figsize Tuple[float, float]

Figure size. Defaults to (10, 6).

(10, 6)
xlabel Optional[str]

Label for the x-axis. Defaults to groupby.

None
ylabel Optional[str]

Label for the y-axis. Defaults to obs_key.

None
title Optional[str]

Title of the plot. Defaults to None.

None
legend_title Optional[str]

Title for the legend. Defaults to the hue variable.

None
show bool

If True, shows the plot. If False, returns the figure. Defaults to True.

True
save Optional[str]

Path to save the figure. Defaults to None.

None
dpi int

DPI for the saved figure. Defaults to 150.

150
**kwargs

Additional arguments passed to sns.boxplot.

{}

Returns:

Type Description
Optional[Figure]

Optional[plt.Figure]: The matplotlib Figure object if show is False, otherwise None.

Raises:

Type Description
ValueError

If groupby, obs_key, or hue are not found in adata.obs or if both palette and cmap are provided.

Source code in src/quicat/plotting/_boxplot.py
def boxplot(
    adata: anndata.AnnData,
    groupby: str,
    obs_key: str,
    hue: Optional[str] = None,
    palette: Optional[Union[str, List[str], Dict[str, str]]] = None,
    cmap: Optional[Dict[str, str]] = None,
    figsize: Tuple[float, float] = (10, 6),
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    title: Optional[str] = None,
    legend_title: Optional[str] = None,
    show: bool = True,
    save: Optional[str] = None,
    dpi: int = 150,
    **kwargs,
) -> Optional[plt.Figure]:
    """
    Creates a box plot showing the distribution of a variable across categories.

    Colors are only applied when `hue` is given. In that case a legend with one
    marker per hue level is drawn outside the axes (upper right). Without `hue`
    no legend is shown and `palette` is validated but not applied.

    Parameters:
        adata (anndata.AnnData): The input AnnData object containing observations.
        groupby (str): The categorical variable in `adata.obs` used for grouping on the x-axis.
        obs_key (str): The variable in `adata.obs` to plot on the y-axis.
        hue (str, optional): The categorical variable in `adata.obs` used for color encoding.
            May be the same column as `groupby`. Defaults to None.
        palette (Union[str, List[str], Dict[str, str]], optional): Seaborn palette name, list of colors, or a dictionary mapping categories to colors.
            Defaults to None.
        cmap (Optional[Dict[str, str]], optional): Dictionary mapping categories in `hue` to specific colors.
            Cannot be combined with `palette`. Defaults to None.
        figsize (Tuple[float, float], optional): Figure size. Defaults to (10, 6).
        xlabel (Optional[str], optional): Label for the x-axis. Defaults to `groupby`.
        ylabel (Optional[str], optional): Label for the y-axis. Defaults to `obs_key`.
        title (Optional[str], optional): Title of the plot. Defaults to None.
        legend_title (Optional[str], optional): Title for the legend. Defaults to the `hue` variable.
        show (bool, optional): If True, shows the plot. If False, returns the figure. Defaults to True.
        save (Optional[str], optional): Path to save the figure. Defaults to None.
        dpi (int, optional): DPI for the saved figure. Defaults to 150.
        **kwargs: Additional arguments passed to `sns.boxplot`. `showfliers` defaults to False and,
            when `hue` is set, `dodge` defaults to False; both can be overridden here.

    Returns:
        Optional[plt.Figure]: The matplotlib Figure object if `show` is False, otherwise None.

    Raises:
        ValueError: If `groupby`, `obs_key`, or `hue` are not found in `adata.obs`, if both `palette` and `cmap`
            are provided, if `cmap` lacks a hue category, if a list `palette` has too few colors, or if
            `palette` has an unsupported type.
    """
    if groupby not in adata.obs.columns:
        e = f"'{groupby}' not found in adata.obs columns."
        raise ValueError(e)
    if obs_key not in adata.obs.columns:
        e = f"'{obs_key}' not found in adata.obs columns."
        raise ValueError(e)
    if hue is not None and hue not in adata.obs.columns:
        e = f"'{hue}' not found in adata.obs columns."
        raise ValueError(e)
    if cmap is not None and palette is not None:
        e = "Specify either 'cmap' or 'palette', not both."
        raise ValueError(e)

    # De-duplicate the selection: with hue == groupby a repeated column name
    # would make `data[hue]` a DataFrame instead of a Series.
    columns = list(dict.fromkeys(c for c in (groupby, obs_key, hue) if c is not None))
    data = adata.obs[columns].copy()

    used_palette = None
    if hue is not None:
        categories = list(data[hue].unique())
        if cmap is not None:
            missing_categories = set(categories) - set(cmap.keys())
            if missing_categories:
                e = f"The following categories are missing in cmap: {missing_categories}"
                raise ValueError(e)
            used_palette = cmap
        elif palette is None:
            uns_palette = adata.uns.get(f"{hue}_colors", None)
            # NOTE(review): stored `<key>_colors` lists are assumed to follow the
            # column's category order (scanpy convention) -- confirm. For
            # non-categorical columns the order of appearance is used.
            dtype_categories = getattr(data[hue].dtype, "categories", None)
            uns_order = list(dtype_categories) if dtype_categories is not None else categories
            if isinstance(uns_palette, list) and len(uns_palette) >= len(uns_order):
                used_palette = dict(zip(uns_order, uns_palette))
            else:
                used_palette = dict(zip(categories, sns.color_palette("Set2", n_colors=len(categories))))
        elif isinstance(palette, str):
            used_palette = dict(zip(categories, sns.color_palette(palette, n_colors=len(categories))))
        elif isinstance(palette, list):
            if len(palette) < len(categories):
                e = "Not enough colors in the palette for the number of categories."
                raise ValueError(e)
            used_palette = dict(zip(categories, palette))
        elif isinstance(palette, dict):
            used_palette = palette
        else:
            e = "Invalid type for 'palette'. Must be a string, list, or dictionary."
            raise ValueError(e)
    elif palette is not None and not isinstance(palette, (str, list)):
        # Without hue no palette is applied, but an unusable value is still rejected.
        e = "Invalid type for 'palette'. Must be a string or list when 'hue' is None."
        raise ValueError(e)

    # Defaults that callers may override through **kwargs.
    kwargs.setdefault("showfliers", False)
    if hue is not None:
        kwargs.setdefault("dodge", False)

    fig, ax = plt.subplots(figsize=figsize)
    sns.boxplot(
        data=data,
        x=groupby,
        y=obs_key,
        hue=hue,
        palette=used_palette if hue is not None else None,
        ax=ax,
        **kwargs,
    )

    ax.set_xlabel(xlabel if xlabel else groupby)
    ax.set_ylabel(ylabel if ylabel else obs_key)
    if title:
        ax.set_title(title)

    if hue is not None:
        _, labels = ax.get_legend_handles_labels()
        # Drop duplicate labels while keeping seaborn's ordering.
        labels = list(dict.fromkeys(labels))
        if labels:
            # Legend labels are strings, palette keys may not be: look up by str().
            color_by_label = {str(key): value for key, value in used_palette.items()}
            custom_handles = [
                Line2D(
                    [0],
                    [0],
                    color=color_by_label[label],
                    marker="o",
                    linestyle="None",
                    markersize=10,
                )
                for label in labels
            ]
            ax.legend(
                handles=custom_handles,
                labels=labels,
                title=legend_title if legend_title else hue,
                bbox_to_anchor=(1.05, 1),
                loc="upper left",
                alignment="left",
                frameon=False,
            )
    elif ax.legend_ is not None:
        ax.legend_.remove()

    fig.tight_layout()
    if save:
        fig.savefig(save, bbox_inches="tight", dpi=dpi)
    if show:
        plt.show()
        return None
    return fig