import textwrap
import traceback
from typing import Any, cast
import matplotlib as mpl
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from seaborn import diverging_palette, stripplot, violinplot
from dae.pheno.utils.lin_regress import LinearRegression
from dae.variants.attributes import Role, Sex, Status
mpl.use("PDF")
plt.ioff()
MAX_CATEGORIES_COUNT: int = 12
ROLES_COUNT_CUTOFF: int = 3
ROLES_GRAPHS_DEFINITION: dict[str, list[Role]] = {
"probands": [Role.prb],
"siblings": [Role.sib],
"parents": [Role.mom, Role.dad],
"grandparents": [
Role.paternal_grandfather,
Role.paternal_grandmother,
Role.maternal_grandfather,
Role.maternal_grandmother,
],
"parental siblings": [
Role.paternal_uncle,
Role.paternal_aunt,
Role.maternal_uncle,
Role.maternal_aunt,
],
"step parents": [Role.step_mom, Role.step_dad],
"half siblings": [Role.paternal_half_sibling, Role.maternal_half_sibling],
"children": [Role.child],
}
[docs]
class GraphColumn:
"""Build a column to produce a graph from it."""
def __init__(
self, name: str,
roles: list[Role],
status: Status,
df: pd.DataFrame,
) -> None:
self.name = name
self.roles = roles
self.status = status
self.df = df
[docs]
def all_count(self) -> int:
return self.df.shape[0]
[docs]
def males_count(self) -> int:
return self.df[self.df.sex == Sex.male].shape[0]
[docs]
def females_count(self) -> int:
return self.df[self.df.sex == Sex.female].shape[0]
@property
def label(self) -> str:
return self.name + "\n" + self.status.name
[docs]
@staticmethod
def build(
df: pd.DataFrame,
role_name: str,
role_subroles: list[Role],
status: Status,
) -> "GraphColumn":
"""Construct a graph column object."""
roles = role_subroles
default_name = ", ".join([role.name for role in roles])
label = role_name if role_name != "" else default_name
df_roles = df[df.role.isin(roles)]
df_roles_status = df_roles[df_roles.status == status]
return GraphColumn(label, roles, status, df_roles_status)
[docs]
def male_female_legend(
color_male: tuple[float, float, float],
color_female: tuple[float, float, float],
ax: Axes | None = None,
) -> None:
"""Construct a legend for female graph."""
if ax is None:
ax = plt.gca()
male_patch = mpatches.Patch(color=color_male, label="M")
female_patch = mpatches.Patch(color=color_female, label="F")
ax.legend(
handles=[male_patch, female_patch], title="Sex", loc="upper right",
)
[docs]
def draw_linregres(
df: pd.DataFrame,
col1: str,
col2: str,
jitter: int | None = None,
ax: Axes | None = None,
) -> tuple:
"""Draw a graph display linear regression between two columns."""
if ax is None:
ax = plt.gca()
df = df.dropna()
dmale = df[df.sex == Sex.male]
dfemale = df[df.sex == Sex.female]
name1 = col1.split(".")[-1]
name2 = col2.split(".")[-1]
ax.set_xlabel(name1)
ax.set_ylabel(name2)
male_x = dmale[[col1]]
male_y = dmale[col2]
if len(male_x) <= 2:
res_male: Any = None
else:
try:
res_male = LinearRegression().calc_regression(
male_x.to_numpy(), male_y)
except ValueError:
traceback.print_exc()
res_male = None
female_x = dfemale[[col1]]
female_y = dfemale[col2]
if len(female_x) <= 2:
res_female: Any = None
else:
try:
res_female = LinearRegression().calc_regression(
female_x.to_numpy(), female_y)
except ValueError:
traceback.print_exc()
res_female = None
jmale1: np.ndarray
jmale2: np.ndarray
jfemale1: np.ndarray
jfemale2: np.ndarray
if jitter is None:
jmale1 = jmale2 = np.zeros(len(dmale[col1]))
jfemale1 = jfemale2 = np.zeros(len(dfemale[col1]))
else:
assert jitter is not None
# pylint: disable=invalid-unary-operand-type
jmale1 = np.random.uniform(-jitter, jitter, len(dmale[col1]))
jmale2 = np.random.uniform(-jitter, jitter, len(dmale[col2]))
jfemale1 = np.random.uniform(-jitter, jitter, len(dfemale[col1]))
jfemale2 = np.random.uniform(-jitter, jitter, len(dfemale[col2]))
color_male, color_female = gender_palette_light()
ax.plot(
dmale[col1].to_numpy() + jmale1,
dmale[col2].to_numpy() + jmale2,
".",
color=color_male,
ms=4,
label="male",
)
ax.plot(
dfemale[col1].to_numpy() + jfemale1,
dfemale[col2].to_numpy() + jfemale2,
".",
color=color_female,
ms=4,
label="female",
)
color_male, color_female = gender_palette()
if res_male:
ax.plot(
dmale[col1].values, # type: ignore
res_male.predict(male_x.to_numpy()), color=color_male,
)
if res_female:
ax.plot(
dfemale[col1].values, # type: ignore
res_female.predict(female_x.to_numpy()), color=color_female,
)
male_female_legend(color_male, color_female, ax)
plt.tight_layout()
return res_male, res_female
[docs]
def column_counts(column: GraphColumn) -> dict[str, str | int]:
"""Collect counts for a graph column."""
return {
"column_name": textwrap.fill(column.name, 9),
"column_status": "$\\it{" + column.status.name + "}$",
"column_total": column.all_count(),
"male_total": column.males_count(),
"female_total": column.females_count(),
}
[docs]
def role_labels(ordered_columns: list[GraphColumn]) -> list[str]:
return [
"{column_name}\n{column_status}\n"
"all:{column_total:>4d}\n"
"M: {male_total:>4d}\n"
"F: {female_total:>4d}".format(**column_counts(col)) # type: ignore
for col in ordered_columns
]
[docs]
def gender_palette_light() -> Any:
return diverging_palette(240, 10, s=80, l=77, n=2)
[docs]
def gender_palette() -> Any:
return diverging_palette(240, 10, s=80, l=50, n=2)
def _enumerate_by_count(
df: pd.DataFrame,
column_name: str,
) -> tuple[pd.Series, list[str]]:
occurrence_counts = df[column_name].value_counts(dropna=False)
occurrence_ordered = cast(
list[str],
occurrence_counts.index.to_numpy().tolist())
occurrences_map = {
value: number for (number, value) in enumerate(occurrence_ordered)
}
return (
df[column_name].apply(lambda x: occurrences_map[x]),
occurrence_ordered,
)
def _enumerate_by_natural_order(
df: pd.DataFrame, column_name: str,
) -> tuple[pd.Series, list]:
values_domain = list(df[column_name].unique())
values_domain = sorted(values_domain, key=float)
values_map = {
value: number for (number, value) in enumerate(values_domain)
}
result = df[column_name].apply(lambda x: values_map[x])
return result, values_domain
[docs]
def draw_measure_violinplot(
df: pd.DataFrame,
measure_id: str,
roles_definition: dict[str, list[Role]] | None = None,
ax: Axes | None = None,
) -> bool:
"""Draw a violin plot for a measure."""
if roles_definition is None:
roles_definition = ROLES_GRAPHS_DEFINITION
if ax is None:
ax = plt.gca()
df = df.copy()
fig = plt.gcf()
palette = gender_palette()
columns = get_columns_to_draw(roles_definition, df)
if len(columns) == 0:
return False
column_dfs = []
for column in columns:
column_df = column.df
column_df["column_name"] = column.label
column_dfs.append(column_df)
df_with_column_names = pd.concat(column_dfs)
assert df.shape[1] == df_with_column_names.shape[1] - 1
column_names = [column.label for column in columns]
set_figure_size(fig, len(columns))
df_with_column_names.sex = df_with_column_names.sex.apply(
lambda s: s.name if s != Sex.unspecified else "M")
violinplot(
data=df_with_column_names,
x="column_name",
y=measure_id,
hue="sex",
order=column_names,
hue_order=[Sex.male.name, Sex.female.name],
linewidth=1,
split=True,
density_norm="count",
common_norm=True,
palette=palette,
saturation=1,
)
palette = gender_palette_light()
stripplot(
data=df_with_column_names,
x="column_name",
y=measure_id,
hue="sex",
order=column_names,
hue_order=[Sex.male.name, Sex.female.name],
jitter=0.025,
size=2,
palette=palette,
linewidth=0.1,
)
labels = role_labels(columns)
plt.xticks(list(range(len(labels))), labels)
ax.set_ylabel(measure_id)
ax.set_xlabel("role")
plt.tight_layout()
return True
[docs]
def get_columns_to_draw(
roles: dict[str, list[Role]],
df: pd.DataFrame,
) -> list[GraphColumn]:
"""Collect columns needed for graphs."""
columns = []
for role_name, role_subroles in roles.items():
columns.extend([
GraphColumn.build(df, role_name, role_subroles, status)
for status in [Status.affected, Status.unaffected]
])
return [
column
for column in columns
if column.all_count() >= ROLES_COUNT_CUTOFF
]
[docs]
def draw_categorical_violin_distribution(
df: pd.DataFrame,
measure_id: str,
*,
roles_definition: dict[str, list[Role]] | None = None,
ax: Axes | None = None,
numerical_categories: bool = False,
max_categories: int = MAX_CATEGORIES_COUNT,
) -> bool:
"""Draw violin distribution for categorical measures."""
if roles_definition is None:
roles_definition = ROLES_GRAPHS_DEFINITION
if ax is None:
ax = plt.gca()
if df.empty:
return False
df = df.copy()
color_male, color_female = gender_palette()
numerical_measure_name = measure_id + "_numerical"
if not numerical_categories:
df[numerical_measure_name], values_domain = _enumerate_by_count(
df, measure_id,
)
else:
(
df[numerical_measure_name],
values_domain,
) = _enumerate_by_natural_order(df, measure_id)
values_domain = values_domain[:max_categories]
bin_edges = np.arange(len(values_domain))
columns = get_columns_to_draw(roles_definition, df)
if len(columns) == 0:
return False
heights = 0.8
hist_range = (
float(np.min(bin_edges)),
float(np.max(bin_edges)),
)
binned_datasets = []
for column in columns:
male_data = column.df[
column.df.sex == Sex.male
][numerical_measure_name].to_numpy()
female_data = column.df[
column.df.sex == Sex.female
][numerical_measure_name].to_numpy()
binned_datasets.append(
[
np.histogram(d, range=hist_range, bins=len(bin_edges))[0]
for d in [male_data, female_data]
],
)
binned_maximum = np.max(
[np.max([np.max(m), np.max(f)]) for (m, f) in binned_datasets],
)
x_locations = np.arange(
0, len(columns) * 2 * binned_maximum, 2 * binned_maximum,
)
fig, ax = plt.subplots()
fig.set_layout_engine("tight")
set_figure_size(fig, len(columns))
female_text_offset = binned_maximum * 0.05
for count, (male, female) in enumerate(binned_datasets):
x_loc = x_locations[count]
lefts = x_loc - male
assert ax is not None
ax.barh(bin_edges, male, height=heights, left=lefts, color=color_male)
ax.barh(
bin_edges, female, height=heights, left=x_loc, color=color_female,
)
# pylint: disable=invalid-name
for y, (male_count, female_count) \
in enumerate(zip(male, female, strict=True)):
ax.text(
x_loc - male_count,
y,
str(male_count),
horizontalalignment="center",
verticalalignment="bottom",
rotation=90,
rotation_mode="anchor",
)
ax.text(
x_loc + female_count + female_text_offset,
y,
str(female_count),
horizontalalignment="center",
verticalalignment="top",
rotation=90,
rotation_mode="anchor",
)
assert ax is not None
ax.set_yticks(bin_edges)
ax.set_yticklabels([
textwrap.fill(str(x), 20) if x is not None else
textwrap.fill("None", 20)
for x in values_domain
])
ax.set_xlim(2 * -binned_maximum, len(columns) * 2 * binned_maximum)
ax.set_ylim(-1, np.max(bin_edges) + 1) # pyright: ignore
ax.set_ylabel(measure_id)
ax.set_xlabel("role")
plt.xticks(x_locations, role_labels(columns))
male_female_legend(color_male, color_female, ax)
return True
[docs]
def draw_ordinal_violin_distribution(
df: pd.DataFrame,
measure_id: str,
ax: Axes | None = None,
) -> bool:
"""Draw violin distribution for ordinal measures."""
if ax is None:
ax = plt.gca()
df = df.copy()
df[measure_id] = df[measure_id].apply(str)
return draw_categorical_violin_distribution(
df, measure_id, numerical_categories=True, ax=ax,
)