Programming basics for Biostatistics 6099

seaborn: statistical data visualization

Zhiguang Huo (Caleb)

Thursday Nov 16th, 2023

Outline

Get started

import matplotlib.pyplot as plt
import seaborn as sns

Overview of the seaborn functionality

tips dataset

# Workaround for "CERTIFICATE_VERIFY_FAILED" errors when sns.load_dataset
# downloads its example data over HTTPS. It must run BEFORE the first
# load_dataset call, otherwise that first download still fails.
# WARNING: this disables HTTPS certificate verification for the whole Python
# process. Prefer fixing the local certificate store (e.g. the
# "Install Certificates.command" shipped with python.org macOS installers)
# and remove these two lines once downloads work without them.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# Restaurant tips: bill, tip, and categorical descriptors of each party.
tips = sns.load_dataset("tips")
tips.head()
##    total_bill   tip     sex smoker  day    time  size
## 0       16.99  1.01  Female     No  Sun  Dinner     2
## 1       10.34  1.66    Male     No  Sun  Dinner     3
## 2       21.01  3.50    Male     No  Sun  Dinner     3
## 3       23.68  3.31    Male     No  Sun  Dinner     2
## 4       24.59  3.61  Female     No  Sun  Dinner     4

relplot (scatter plot)

# Figure-level scatter plot: tip amount against total bill.
sns.relplot(
    data=tips,
    x="total_bill",
    y="tip",
    kind="scatter",
)

# Axes-level equivalent: sns.scatterplot(data=tips, x="total_bill", y="tip")

color and style

# Color the points by smoker status.
sns.relplot(x="total_bill", y="tip", hue="smoker", data=tips)

# Encode smoker status twice: color AND marker shape.
sns.relplot(
    x="total_bill", y="tip",
    hue="smoker", style="smoker",
    data=tips,
)

size and transparency

# Marker size follows the `size` column.
sns.relplot(x="total_bill", y="tip", size="size", data=tips)

# Same plot with semi-transparent markers so overlapping points remain visible.
sns.relplot(x="total_bill", y="tip", size="size", alpha=0.5, data=tips)

multiple subfigures

# One panel per value of `time` (col=), with color, marker and size
# encodings layered on top of the basic scatter.
sns.relplot(
    data=tips, x="total_bill", y="tip",
    col="time",
    hue="smoker", style="smoker", size="size",
)

scatter plot with a fitted linear regression line

# Scatter plus a fitted regression line per smoker group, one panel per time.
sns.lmplot(
    data=tips,
    x="total_bill",
    y="tip",
    hue="smoker",
    col="time",
)

styles

# White background with grid lines; this global setting also affects later plots.
sns.set_style("whitegrid")
sns.relplot(x="total_bill", y="tip", kind="scatter", data=tips)

Alternative: set the same style through set_theme:

sns.set_theme(style="whitegrid")  # style chosen through the all-in-one theme setter

Scaling Plots – set_context

# The "talk" context enlarges fonts, lines and markers (suited to slides).
sns.set_style("whitegrid")
sns.set_context("talk")
sns.relplot(x="total_bill", y="tip", kind="scatter", data=tips)

Alternative: set both context and style in a single set_theme call:

# Single-call alternative to sns.set_style("whitegrid") + sns.set_context("talk").
# Note: set_theme also (re)applies seaborn's default palette and font settings.
sns.set_theme(context="talk", style="whitegrid")

fmri dataset

# fMRI signal recorded per subject, timepoint, event type and brain region.
fmri = sns.load_dataset(name="fmri")
fmri.head(5)
##   subject  timepoint event    region    signal
## 0     s13         18  stim  parietal -0.017552
## 1      s5         14  stim  parietal -0.080883
## 2     s12         18  stim  parietal -0.081033
## 3     s11         18  stim  parietal -0.046134
## 4     s10         18  stim  parietal -0.037970
# Narrow down one condition at a time; each intermediate subset is plotted later:
#   fmri_sub13               -> subject s13 only
#   fmri_sub13_stim          -> ... restricted to the "stim" event
#   fmri_sub13_stim_parietal -> ... restricted to the parietal region
fmri_sub13 = fmri.loc[fmri["subject"] == "s13"]
fmri_sub13_stim = fmri_sub13.loc[fmri_sub13["event"] == "stim"]
fmri_sub13_stim_parietal = fmri_sub13_stim.loc[fmri_sub13_stim["region"] == "parietal"]
fmri_sub13_stim_parietal.head(5)
##    subject  timepoint event    region    signal
## 0      s13         18  stim  parietal -0.017552
## 15     s13         17  stim  parietal -0.008265
## 29     s13         16  stim  parietal -0.002856
## 43     s13         15  stim  parietal -0.010971
## 57     s13         14  stim  parietal -0.033713

relplot (line plot)

# Signal over time for one subject / event / region combination.
sns.relplot(
    data=fmri_sub13_stim_parietal,
    x="timepoint",
    y="signal",
    kind="line",
)

# Axes-level equivalent: sns.lineplot(data=fmri_sub13_stim_parietal, x="timepoint", y="signal")

combine scatter and line plots

# Overlay two axes-level plots by passing the same Axes object to both.
fig, ax = plt.subplots(figsize=(6, 4))
sns.scatterplot(ax=ax, data=fmri_sub13_stim_parietal, x="timepoint", y="signal")
sns.lineplot(ax=ax, data=fmri_sub13_stim_parietal, x="timepoint", y="signal")

relplot (line plot)

# Subject s13, stim event only: one colored line per region.
sns.relplot(data=fmri_sub13_stim, x="timepoint", y="signal", hue="region", kind="line")

# Both events: color encodes region, line style encodes event.
sns.relplot(data=fmri_sub13, x="timepoint", y="signal", hue="region", style="event", kind="line")

# Both events again, split into one panel per event instead of line styles.
sns.relplot(data=fmri_sub13, x="timepoint", y="signal", hue="region", col="event", kind="line")

multiple subjects

# All subjects pooled: repeated timepoints are aggregated into one line per region
# (by default an estimate with an uncertainty band).
sns.relplot(
    data=fmri, x="timepoint", y="signal",
    hue="region", col="event", kind="line",
)

Exercise: for the stim event in the parietal region, plot the signal over time with one line per subject.

# Keep only stim-event, parietal-region rows, across all subjects.
is_target = (fmri["event"] == "stim") & (fmri["region"] == "parietal")
fmri_sub = fmri.loc[is_target]
fmri_sub.head(5)
##   subject  timepoint event    region    signal
## 0     s13         18  stim  parietal -0.017552
## 1      s5         14  stim  parietal -0.080883
## 2     s12         18  stim  parietal -0.081033
## 3     s11         18  stim  parietal -0.046134
## 4     s10         18  stim  parietal -0.037970
# One line per subject.
sns.relplot(data=fmri_sub, x="timepoint", y="signal", hue="subject", kind="line")

histogram

# Single-panel histogram: sns.displot(data=tips, x="total_bill")

# Histogram of total_bill per meal time, with a KDE curve drawn over the bars.
sns.displot(x="total_bill", col="time", kde=True, data=tips)

## Density curve only, no bars:
## sns.displot(data=tips, x="total_bill", col="time", kind="kde")

empirical cumulative distribution function

# Empirical CDF of total_bill per smoker group, one panel per meal time.
sns.displot(
    data=tips,
    x="total_bill",
    kind="ecdf",
    col="time",
    hue="smoker",
)

# Add rug ticks marking every observation:
# sns.displot(data=tips, kind="ecdf", x="total_bill", col="time", hue="smoker", rug=True)

Visualizing categorical data

catplot (strip plot)

# Jittered strip plot; both smoker groups share one column per day.
sns.catplot(data=tips, x="day", y="total_bill", hue="smoker", kind="strip")

# dodge=True places the smoker groups side by side within each day.
sns.catplot(data=tips, x="day", y="total_bill", hue="smoker", kind="strip", dodge=True)
# Axes-level counterpart (no dodge here):
sns.stripplot(data=tips, x="day", y="total_bill", hue="smoker")

catplot (swarm plot)

# Swarm plot: points are spread out so they do not overlap; groups dodged by smoker.
sns.catplot(data=tips, x="day", y="total_bill", hue="smoker", kind="swarm", dodge=True)

# Same swarm without dodging.
sns.catplot(data=tips, x="day", y="total_bill", hue="smoker", kind="swarm")
# Axes-level counterpart:
sns.swarmplot(data=tips, x="day", y="total_bill", hue="smoker")

catplot (boxplot)

# Box plot of total_bill per day, split by smoker status.
sns.catplot(
    data=tips, x="day", y="total_bill", hue="smoker",
    kind="box",
)

# Axes-level equivalent: sns.boxplot(data=tips, x="day", y="total_bill", hue="smoker")

overlay a box plot and a jittered strip plot

# Raw (jittered, dodged) points and box plots drawn on one shared Axes.
# legend=False on the strip layer leaves only the box plot's smoker legend.
fig, ax = plt.subplots(figsize=(6, 4))
sns.stripplot(ax=ax, data=tips, x="day", y="total_bill", hue="smoker", dodge=True, legend=False)
sns.boxplot(ax=ax, data=tips, x="day", y="total_bill", hue="smoker")

catplot (violin)

# Violin plot of total_bill per day, split by smoker status.
sns.catplot(
    data=tips, x="day", y="total_bill", hue="smoker",
    kind="violin",
)

# Axes-level equivalent: sns.violinplot(data=tips, x="day", y="total_bill", hue="smoker")

catplot (pointplot)

# Point plot: one summary point per day and smoker group, connected across days.
sns.catplot(
    data=tips, x="day", y="total_bill", hue="smoker",
    kind="point",
)

# Axes-level equivalent: sns.pointplot(data=tips, x="day", y="total_bill", hue="smoker")

catplot (barplot)

# Bar plot: one bar per day and smoker group, summarizing total_bill.
sns.catplot(data=tips, kind="bar", x="day", y="total_bill", hue="smoker")

# Axes-level equivalents:
# sns.barplot(data=tips, x="day", y="total_bill", hue="smoker")
# sns.barplot(data=tips, x="day", y="total_bill", hue="smoker", errorbar=None)  ## no error bars (`ci=None` is deprecated since seaborn 0.12)
# sns.barplot(data=tips, y="day", x="total_bill", hue="smoker") ## horizontal barplot

catplot (countplot)

# Count plot: number of rows per day, split by smoker status (no y variable needed).
sns.catplot(
    data=tips, x="day", hue="smoker",
    kind="count",
)

# Axes-level equivalent: sns.countplot(data=tips, x="day", hue="smoker")

iris dataset

# Iris flowers: four numeric measurements plus the species label.
iris = sns.load_dataset(name="iris")
iris.head(5)
##    sepal_length  sepal_width  petal_length  petal_width species
## 0           5.1          3.5           1.4          0.2  setosa
## 1           4.9          3.0           1.4          0.2  setosa
## 2           4.7          3.2           1.3          0.2  setosa
## 3           4.6          3.1           1.5          0.2  setosa
## 4           5.0          3.6           1.4          0.2  setosa

pairplot

# Grid of pairwise plots of the numeric columns, colored by species.
p1 = sns.pairplot(
    iris,
    hue="species",
)
# Without coloring: sns.pairplot(iris)
plt.show()

heatmap

# Drop the non-numeric label column so only measurements are color-coded.
iris_sub = iris.drop("species", axis=1)
sns.heatmap(iris_sub)

Reference

https://seaborn.pydata.org/tutorial