import textwrap
import traceback
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from seaborn import diverging_palette, stripplot, violinplot
from dae.pheno.utils.lin_regress import LinearRegression
from dae.variants.attributes import Role, Sex, Status
# Select matplotlib's "PDF" backend and switch pyplot's interactive mode off.
mpl.use("PDF")
plt.ioff()

# Default for the ``max_categories`` argument of
# ``draw_categorical_violin_distribution``: the category domain is truncated
# to at most this many entries.
MAX_CATEGORIES_COUNT = 12

# ``get_columns_to_draw`` keeps only the columns whose row count is at least
# this value.
ROLES_COUNT_CUTOFF = 7

# Default roles definition used by the violin plots: maps a graph column
# name to the list of roles whose rows are grouped into that column.
ROLES_GRAPHS_DEFINITION = {
    "probands": [Role.prb],
    "siblings": [Role.sib],
    "parents": [Role.mom, Role.dad],
    "grandparents": [
        Role.paternal_grandfather,
        Role.paternal_grandmother,
        Role.maternal_grandfather,
        Role.maternal_grandmother,
    ],
    "parental siblings": [
        Role.paternal_uncle,
        Role.paternal_aunt,
        Role.maternal_uncle,
        Role.maternal_aunt,
    ],
    "step parents": [Role.step_mom, Role.step_dad],
    "half siblings": [Role.paternal_half_sibling, Role.maternal_half_sibling],
    "children": [Role.child],
}
[docs]
class GraphColumn:
    """Rows of a data frame that make up one column of a graph.

    A column is described by a display name, the roles it groups, a status
    and the data frame holding the matching rows.
    """

    def __init__(self, name, roles, status, df):
        self.name = name
        self.roles = roles
        self.status = status
        self.df = df

    def _count_sex(self, sex):
        """Return the number of rows whose ``sex`` equals the given value."""
        matching = self.df[self.df.sex == sex]
        return len(matching.index)

    def all_count(self):
        """Return the total number of rows in the column."""
        return len(self.df.index)

    def males_count(self):
        """Return the number of rows with ``Sex.male``."""
        return self._count_sex(Sex.male)

    def females_count(self):
        """Return the number of rows with ``Sex.female``."""
        return self._count_sex(Sex.female)

    @property
    def label(self):
        """Column name and status name, separated by a newline."""
        return "\n".join((self.name, self.status.name))

    @staticmethod
    def build(df, role_name, role_subroles, status):
        """Create the column for the given roles and status.

        Rows of ``df`` are kept when their ``role`` is one of
        ``role_subroles`` and their ``status`` equals ``status``.  An empty
        ``role_name`` is replaced by the comma-separated role names.
        """
        joined_role_names = ", ".join([role.name for role in role_subroles])
        column_name = joined_role_names if role_name == "" else role_name

        with_role = df[df.role.isin(role_subroles)]
        selected = with_role[with_role.status == status]
        return GraphColumn(column_name, role_subroles, status, selected)
[docs]
def names(col1, col2):
    """Return the part after the last dot of each of the two column ids."""
    short1, short2 = (col.rpartition(".")[2] for col in (col1, col2))
    return (short1, short2)
[docs]
def male_female_colors():
    """Return the ``(male, female)`` colors taken from ``gender_palette``."""
    male, female = gender_palette()
    return male, female
[docs]
def male_female_legend(color_male, color_female, ax=None):
    """Add an "M"/"F" legend titled "Sex" to the given or current axes."""
    target = plt.gca() if ax is None else ax

    # pylint: disable=import-outside-toplevel
    import matplotlib.patches as mpatches

    handles = [
        mpatches.Patch(color=color, label=label)
        for color, label in ((color_male, "M"), (color_female, "F"))
    ]
    target.legend(handles=handles, title="Sex", loc="upper right")
[docs]
def _fit_regression(data, col1, col2):
    """Fit ``col2`` against ``col1``; return ``(model, x)``.

    ``model`` is ``None`` when fitting raises ``ValueError``; the traceback
    is printed in that case.
    """
    x = data[[col1]]
    try:
        model = LinearRegression().fit(x.to_numpy(), data[col2])
    except ValueError:
        traceback.print_exc()
        model = None
    return model, x


def _jitter_pair(jitter, size):
    """Return the x and y offsets added to ``size`` plotted points."""
    if jitter is None:
        zeros = np.zeros(size)
        return zeros, zeros
    return (
        np.random.uniform(-jitter, jitter, size),
        np.random.uniform(-jitter, jitter, size),
    )


def draw_linregres(df, col1, col2, jitter: float | None = None, ax=None):
    """Draw a scatter plot and per-sex regression lines for two columns.

    Rows containing missing values are dropped first.  The remaining rows
    are split into ``Sex.male`` and ``Sex.female`` groups; each group is
    drawn as points and, when its regression could be fitted, as a line.

    Args:
        df: data frame with a ``sex`` column and the columns ``col1``
            and ``col2``.
        col1: column used for the x axis and as regression input.
        col2: column used for the y axis and as regression target.
        jitter: when not ``None``, every plotted point is shifted on both
            axes by a uniform random offset from ``[-jitter, jitter)``.
            The regression itself is fitted on the unshifted data.
        ax: axes to draw on; the current axes are used when ``None``.

    Returns:
        ``(res_male, res_female)``: the fitted regression objects, each
        ``None`` when fitting that group raised ``ValueError``.
    """
    if ax is None:
        ax = plt.gca()

    df = df.dropna()
    dmale = df[df.sex == Sex.male]
    dfemale = df[df.sex == Sex.female]

    name1, name2 = names(col1, col2)
    ax.set_xlabel(name1)
    ax.set_ylabel(name2)

    res_male, male_x = _fit_regression(dmale, col1, col2)
    res_female, female_x = _fit_regression(dfemale, col1, col2)

    # Male offsets are drawn before female ones, x before y.
    jmale1, jmale2 = _jitter_pair(jitter, len(dmale))
    jfemale1, jfemale2 = _jitter_pair(jitter, len(dfemale))

    color_male, color_female = gender_palette_light()
    ax.plot(
        dmale[col1].values + jmale1,
        dmale[col2].values + jmale2,
        ".",
        color=color_male,
        ms=4,
        label="male",
    )
    ax.plot(
        dfemale[col1].values + jfemale1,
        dfemale[col2].values + jfemale2,
        ".",
        color=color_female,
        ms=4,
        label="female",
    )

    # The regression lines and the legend use the darker palette.
    color_male, color_female = gender_palette()
    # Compare against None explicitly: whether a line is drawn must not
    # depend on the truth value of the regression object.
    if res_male is not None:
        ax.plot(
            dmale[col1].values,
            res_male.predict(male_x), color=color_male,
        )
    if res_female is not None:
        ax.plot(
            dfemale[col1].values,
            res_female.predict(female_x), color=color_female,
        )

    male_female_legend(color_male, color_female, ax)
    plt.tight_layout()
    return res_male, res_female
[docs]
def draw_distribution(df, measure_id, ax=None):
    """Draw a stacked, per-sex histogram of a measure.

    Args:
        df: data frame with a ``sex`` column and a ``measure_id`` column.
        measure_id: name of the column to plot; also used as x label.
        ax: axes to draw on; the current axes are used when ``None``.
    """
    if ax is None:
        ax = plt.gca()

    color_male, color_female = male_female_colors()
    ax.hist(
        [
            df[df.sex == Sex.female][measure_id],
            df[df.sex == Sex.male][measure_id],
        ],
        color=[color_female, color_male],
        stacked=True,
        bins=20,
        # ``density`` replaces the ``normed`` keyword, which current
        # matplotlib releases no longer accept; plain counts are plotted.
        density=False,
    )
    male_female_legend(color_male, color_female, ax)
    ax.set_xlabel(measure_id)
    ax.set_ylabel("count")
    plt.tight_layout()
[docs]
def column_counts(column):
    """Return the name, status and row counts used to label a graph column.

    The name is wrapped to a width of 9 characters and the status name is
    wrapped in ``$\\it{...}$``.
    """
    wrapped_name = textwrap.fill(column.name, 9)
    status_text = "$\\it{" + column.status.name + "}$"
    return dict(
        column_name=wrapped_name,
        column_status=status_text,
        column_total=column.all_count(),
        male_total=column.males_count(),
        female_total=column.females_count(),
    )
[docs]
def role_labels(ordered_columns):
    """Return one multi-line tick label per column, in the given order.

    Each label lists the column name, its status and the total, male and
    female counts reported by ``column_counts``.
    """
    template = (
        "{column_name}\n{column_status}\n"
        "all:{column_total:>4d}\n"
        "M: {male_total:>4d}\n"
        "F: {female_total:>4d}"
    )
    labels = []
    for column in ordered_columns:
        labels.append(template.format(**column_counts(column)))
    return labels
[docs]
def gender_palette_light():
    """Return the two-color diverging palette with lightness 77.

    Callers unpack the result as ``(male, female)``.
    """
    return diverging_palette(240, 10, s=80, l=77, n=2)
[docs]
def gender_palette():
    """Return the two-color diverging palette with lightness 50.

    Callers unpack the result as ``(male, female)``.
    """
    return diverging_palette(240, 10, s=80, l=50, n=2)
def _enumerate_by_count(df, column_name):
    """Number the values of a column by descending frequency.

    Missing values are counted as a category of their own.

    Args:
        df: data frame containing ``column_name``.
        column_name: name of the column to enumerate.

    Returns:
        ``(codes, domain)``: ``domain`` lists the distinct values ordered
        as ``value_counts`` returns them (most frequent first) and
        ``codes`` is a series holding, for every row, the position of its
        value in ``domain``.
    """

    def is_missing(value):
        # ``pd.isna`` returns an array for list-like input; only a scalar
        # "True" marks a missing value.
        flag = pd.isna(value)
        return isinstance(flag, (bool, np.bool_)) and bool(flag)

    values = df[column_name]
    occurrence_ordered = values.value_counts(dropna=False).index.tolist()

    # NaN never compares equal to itself, so a missing value cannot be
    # looked up as a dict key reliably; its code is tracked separately.
    codes = {}
    missing_code = None
    for number, value in enumerate(occurrence_ordered):
        if is_missing(value):
            if missing_code is None:
                missing_code = number
        else:
            codes[value] = number

    def encode(value):
        return missing_code if is_missing(value) else codes[value]

    return values.apply(encode), occurrence_ordered
def _enumerate_by_natural_order(df, column_name):
    """Number the values of a column by their numeric order.

    The distinct values are sorted by their ``float`` conversion; the
    function returns a series with the position of every row's value in
    that ordering, together with the sorted list of distinct values.
    """
    ordered_values = sorted(df[column_name].unique(), key=float)
    position_of = {}
    for position, value in enumerate(ordered_values):
        position_of[value] = position
    positions = df[column_name].apply(lambda value: position_of[value])
    return positions, ordered_values
[docs]
def draw_measure_violinplot(
        df, measure_id, roles_definition=None, ax=None):
    """Draw a split violin plot of a measure per role/status column.

    For every column returned by ``get_columns_to_draw`` a violin split by
    sex is drawn, overlaid with a strip plot of the individual values.

    Args:
        df: data frame with ``role``, ``status`` and ``sex`` columns and
            a ``measure_id`` column.
        measure_id: name of the column to plot; also used as y label.
        roles_definition: mapping of column name to list of roles;
            ``ROLES_GRAPHS_DEFINITION`` is used when ``None``.
        ax: axes to draw on; the current axes are used when ``None``.

    Returns:
        ``False`` when there is no column to draw, ``True`` otherwise.
    """
    if roles_definition is None:
        roles_definition = ROLES_GRAPHS_DEFINITION
    if ax is None:
        ax = plt.gca()

    df = df.copy()
    fig = plt.gcf()
    palette = gender_palette()

    columns = get_columns_to_draw(roles_definition, df)
    if not columns:
        return False

    # ``assign`` builds new frames, so the filtered frames held by the
    # columns are not modified.
    column_dfs = [
        column.df.assign(column_name=column.label) for column in columns
    ]
    df_with_column_names = pd.concat(column_dfs)
    assert df.shape[1] == df_with_column_names.shape[1] - 1

    column_names = [column.label for column in columns]
    set_figure_size(fig, len(columns))

    # Both seaborn plots are given ``ax`` explicitly so that they draw on
    # the same axes that receive the labels below.
    violinplot(
        data=df_with_column_names,
        x="column_name",
        y=measure_id,
        hue="sex",
        order=column_names,
        hue_order=[Sex.male, Sex.female],
        linewidth=1,
        split=True,
        density_norm="count",
        common_norm=True,
        palette=palette,
        saturation=1,
        ax=ax,
    )

    # For the strip plot only, rows with unspecified sex are shown with
    # the male hue.
    df_with_column_names["sex"] = df_with_column_names["sex"].apply(
        lambda s: s if s != Sex.unspecified else Sex.male)

    palette = gender_palette_light()
    stripplot(
        data=df_with_column_names,
        x="column_name",
        y=measure_id,
        hue="sex",
        order=column_names,
        hue_order=[Sex.male, Sex.female],
        jitter=0.025,
        size=2,
        palette=palette,
        linewidth=0.1,
        ax=ax,
    )

    labels = role_labels(columns)
    ax.set_xticks(list(range(len(labels))))
    ax.set_xticklabels(labels)
    ax.set_ylabel(measure_id)
    ax.set_xlabel("role")
    plt.tight_layout()
    return True
[docs]
def get_columns_to_draw(roles, df):
    """Build the graph columns that have enough rows to be drawn.

    One column is built per role group and per status (affected first,
    then unaffected); columns with fewer than ``ROLES_COUNT_CUTOFF`` rows
    are left out.
    """
    candidates = (
        GraphColumn.build(df, role_name, role_subroles, status)
        for role_name, role_subroles in roles.items()
        for status in (Status.affected, Status.unaffected)
    )
    return [
        candidate
        for candidate in candidates
        if not candidate.all_count() < ROLES_COUNT_CUTOFF
    ]
[docs]
def draw_categorical_violin_distribution(
    df,
    measure_id,
    roles_definition=None,
    ax=None,
    numerical_categories=False,
    max_categories=MAX_CATEGORIES_COUNT,
):
    """Draw back-to-back male/female bar charts of a categorical measure.

    For every column returned by ``get_columns_to_draw`` the number of
    male and female rows per category is drawn as horizontal bars, male
    bars growing to the left and female bars to the right of the column
    position; each bar is annotated with its count.

    Args:
        df: data frame with ``role``, ``status`` and ``sex`` columns and
            a ``measure_id`` column.
        measure_id: name of the categorical column; also used as y label.
        roles_definition: mapping of column name to list of roles;
            ``ROLES_GRAPHS_DEFINITION`` is used when ``None``.
        ax: NOTE(review): the plot is always drawn on axes of a figure
            newly created with ``plt.subplots()``; the passed axes are
            not drawn on.
        numerical_categories: when true, categories are ordered by their
            numeric value, otherwise by descending frequency.
        max_categories: maximum number of categories shown.

    Returns:
        ``False`` when ``df`` is empty or there is no column to draw,
        ``True`` otherwise.
    """
    if roles_definition is None:
        roles_definition = ROLES_GRAPHS_DEFINITION
    if ax is None:
        ax = plt.gca()
    if df.empty:
        return False

    df = df.copy()
    color_male, color_female = male_female_colors()

    # Replace every category by its position in the ordered domain.
    numerical_measure_name = measure_id + "_numerical"
    if not numerical_categories:
        df[numerical_measure_name], values_domain = _enumerate_by_count(
            df, measure_id,
        )
    else:
        (
            df[numerical_measure_name],
            values_domain,
        ) = _enumerate_by_natural_order(df, measure_id)

    values_domain = values_domain[:max_categories]
    y_locations = np.arange(len(values_domain))

    columns = get_columns_to_draw(roles_definition, df)
    if not columns:
        return False

    centers = y_locations
    heights = 0.8
    # Positions outside this range (categories beyond ``max_categories``)
    # are not counted by ``np.histogram``.
    hist_range = (np.min(y_locations), np.max(y_locations))

    binned_datasets = []
    for column in columns:
        df_role = column.df
        df_male = df_role[df_role.sex == Sex.male]
        df_female = df_role[df_role.sex == Sex.female]
        male_data = df_male[numerical_measure_name].values
        female_data = df_female[numerical_measure_name].values
        binned_datasets.append(
            [
                np.histogram(d, range=hist_range, bins=len(y_locations))[0]
                for d in [male_data, female_data]
            ],
        )

    binned_maximum = np.max(
        [np.max([np.max(m), np.max(f)]) for (m, f) in binned_datasets],
    )
    # The tallest bar determines the spacing between columns.  When no
    # drawn column has male or female rows the maximum is 0; fall back to
    # a spacing of 2 so ``np.arange`` is not given a zero step.
    column_spacing = 2 * max(binned_maximum, 1)
    x_locations = np.arange(
        0, len(columns) * column_spacing, column_spacing,
    )

    _fig, ax = plt.subplots()
    _fig.set_tight_layout(True)
    set_figure_size(_fig, len(columns))

    female_text_offset = binned_maximum * 0.05

    for count, (male, female) in enumerate(binned_datasets):
        x_loc = x_locations[count]

        # Male bars end at the column position, female bars start there.
        lefts = x_loc - male
        ax.barh(centers, male, height=heights, left=lefts, color=color_male)
        ax.barh(
            centers, female, height=heights, left=x_loc, color=color_female,
        )

        # pylint: disable=invalid-name
        for y, (male_count, female_count) in enumerate(zip(male, female)):
            ax.text(
                x_loc - male_count,
                y,
                str(male_count),
                horizontalalignment="center",
                verticalalignment="bottom",
                rotation=90,
                rotation_mode="anchor",
            )
            ax.text(
                x_loc + female_count + female_text_offset,
                y,
                str(female_count),
                horizontalalignment="center",
                verticalalignment="top",
                rotation=90,
                rotation_mode="anchor",
            )

    ax.set_yticks(y_locations)
    ax.set_yticklabels([
        textwrap.fill(str(x), 20) if x is not None else
        textwrap.fill("None", 20)
        for x in values_domain
    ])
    ax.set_xlim(-column_spacing, len(columns) * column_spacing)
    ax.set_ylim(-1, np.max(y_locations) + 1)
    ax.set_ylabel(measure_id)
    ax.set_xlabel("role")

    labels = role_labels(columns)
    plt.xticks(x_locations, labels)
    male_female_legend(color_male, color_female, ax)
    return True
[docs]
def draw_ordinal_violin_distribution(df, measure_id, ax=None):
    """Draw the distribution of an ordinal measure.

    The measure values are converted to strings on a copy of ``df`` and
    the result is drawn by ``draw_categorical_violin_distribution`` with
    ``numerical_categories=True``; its return value is passed through.
    """
    target_ax = plt.gca() if ax is None else ax

    stringified = df.copy()
    stringified[measure_id] = stringified[measure_id].apply(str)

    return draw_categorical_violin_distribution(
        stringified,
        measure_id,
        ax=target_ax,
        numerical_categories=True,
    )