

Tutorial on how to make beautiful plots with Python

python matplotlib seaborn

By Afshine Amidi and Shervine Amidi

Motivation

The Department of Transportation publicly released a dataset that lists flights that occurred in 2015, along with details such as delays and flight times. Our previous post detailed best practices for manipulating data.

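This tutorial aims at showing good practices for visualizing data with Python's most popular plotting libraries, matplotlib and seaborn. It covers temporal plots, visualizations across multiple dimensions, and customized plots. The code relies on the following imports: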

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import pandas as pd
import seaborn as sns
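Each plot below takes a pandas DataFrame `data` in long format, with the columns listed in the comment at the top of each code block. As a minimal sketch, assuming a per-flight table with hypothetical columns `date` and `airline` (the toy rows are made up), the (date, airline, nb_flights) frame of the first plot could be aggregated with `groupby`:

```python
import pandas as pd

# Toy per-flight records (made-up values; column names are assumptions)
flights = pd.DataFrame({
    'date': pd.to_datetime(['2015-01-01', '2015-01-01', '2015-01-01', '2015-01-02']),
    'airline': ['AA', 'AA', 'DL', 'DL'],
})

# One row per (date, airline) pair, counting the flights of that day
data = (flights
        .groupby(['date', 'airline'])
        .size()
        .reset_index(name='nb_flights'))
print(data)
```

Keeping `date` as a datetime column lets the `matplotlib.dates` locators and formatters used in the plots work on a real date axis.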

Temporal plots

Evolution of number of flights

We want to plot the daily number of flights of major US airlines over time.

The code used to produce the plot is shown below.

# data: DataFrame with columns date (datetime), airline, nb_flights
# Style
sns.set(font_scale=1)
sns.set_style("whitegrid")

# Plot
fig, ax = plt.subplots(figsize=(10, 3))
sns.lineplot(data=data, x='date', y='nb_flights', hue='airline', ax=ax)
ax.axvline(x=pd.Timestamp('2015-07-01'), linestyle='--', color='grey')

# Axes
ax.xaxis.set_major_locator(mdates.DayLocator(bymonthday=1))
ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %d'))
plt.xticks(rotation=15)
plt.ylim(0, 4000)

# Legend: set the title explicitly instead of overwriting labels[0],
# which may hold an airline name rather than the hue title
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, labels, title='Airline', bbox_to_anchor=(1.04, 0.5), loc="center left",
          borderaxespad=0, frameon=False)
plt.suptitle('Number of flights in the US in 2015', x=0.28, y=1.05)
plt.title('Top airlines', loc='left')
ax.text(x=0.75, y=-0.245, s='Source: publicly available data from DoT',
        fontsize=8, transform=ax.transAxes)
plt.xlabel('Time')
plt.ylabel('Number of daily flights')
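The `# Axes` step places one major tick on the first day of each month with `DayLocator(bymonthday=1)` and labels it with the `strftime`-style pattern `'%b %d'`. This minimal matplotlib-only sketch uses synthetic daily values, not the DoT data, to show the resulting tick labels:

```python
import datetime as dt

import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs headless
import matplotlib.dates as mdates
import matplotlib.pyplot as plt

# Synthetic daily series covering 2015 (placeholder values)
days = [dt.date(2015, 1, 1) + dt.timedelta(days=i) for i in range(365)]
values = [i % 30 for i in range(365)]

fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(days, values)
ax.set_xlim(dt.date(2015, 1, 1), dt.date(2015, 12, 31))

# One tick on the 1st of every month, labelled like "Jul 01"
ax.xaxis.set_major_locator(mdates.DayLocator(bymonthday=1))
ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %d'))

fig.canvas.draw()  # make sure tick labels are computed
labels = [t.get_text() for t in ax.get_xticklabels()]
print(labels)
```

`mdates.MonthLocator()` is an alternative that also ticks on the first of each month. `DayLocator(bymonthday=...)` has the advantage of accepting several days, for example `bymonthday=[1, 15]`.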

Evolution of traffic per airport

We want to plot the daily number of flights departing from the top Hawaiian airports over time, with selected time periods highlighted.

The code used to produce the plot is shown below.

# data: DataFrame with columns date (datetime), airport, nb_flights
# time_periods: DataFrame with columns xmin, xmax, ymin, ymax, name, color
# Style
sns.set(font_scale=1)
sns.set_style("whitegrid")

# Plot
fig, ax = plt.subplots(figsize=(10, 3))
sns.lineplot(data=data, x='date', y='nb_flights', hue='airport', ax=ax)
# Shade each time period with a translucent vertical band
for row in time_periods.itertuples(index=False):
    ax.axvspan(row.xmin, row.xmax, facecolor=row.color, alpha=0.2)

# Axes
ax.xaxis.set_major_locator(mdates.DayLocator(bymonthday=1))
ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %d'))
plt.xticks(rotation=15)
plt.ylim(0, 150)

# Legend: set the title explicitly instead of overwriting labels[0]
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, labels, title='Airport', bbox_to_anchor=(1.04, 0.5), loc="center left",
          borderaxespad=0, frameon=False)
plt.title('Top 5 Hawaii airports', loc='left')
plt.suptitle('Number of flights departing from Hawaiian airports in 2015', x=0.38, y=1.05)
ax.text(x=0.75, y=-0.245, s='Source: publicly available data from DoT',
        fontsize=8, transform=ax.transAxes)
plt.xlabel('Time')
plt.ylabel('Number of daily flights')
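The loop over `time_periods` shades each period with `axvspan`, which by default spans the full height of the axes. The sketch below uses a made-up `time_periods` table: the dates and names are placeholders, not events from the dataset, and only the `xmin`, `xmax`, `name` and `color` columns are used. Iterating with `itertuples` avoids assuming the table has a 0..n-1 index:

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical periods to highlight (placeholder dates and labels)
time_periods = pd.DataFrame({
    'xmin': pd.to_datetime(['2015-03-01', '2015-07-01']),
    'xmax': pd.to_datetime(['2015-04-01', '2015-09-01']),
    'name': ['Period A', 'Period B'],
    'color': ['tab:orange', 'tab:green'],
})

# Synthetic daily traffic (placeholder values)
dates = pd.date_range('2015-01-01', '2015-12-31', freq='D')
traffic = pd.Series(range(len(dates)), index=dates) % 50 + 50

fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(traffic.index, traffic.values)

# One translucent vertical band per period; its label feeds the legend
for row in time_periods.itertuples(index=False):
    ax.axvspan(row.xmin, row.xmax, facecolor=row.color, alpha=0.2, label=row.name)

ax.legend(loc='upper left')
```

Passing `label=` is optional. Without it, the bands stay out of the legend, which matches the original plot.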

Visualizing across multiple dimensions

Delay by hour of day for top airlines

We can also look at delays from a different angle with a heat map showing the delay of each top airline by hour of day.

The code used to produce the plot is shown below.

# data: DataFrame with one row per airline and one column per hour of day
# Plot
fig, ax = plt.subplots(figsize=(15, 5))
sns.heatmap(data, cmap='RdBu_r', ax=ax, linecolor='black', linewidths=0.01)

# Axes
ax.invert_yaxis()
ax.set_xlabel('Hour of day')
ax.set_ylabel('')
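`sns.heatmap` draws a 2D table as a grid of colored cells. The `data` passed to it therefore needs one row per airline and one column per hour, not the long format used by the earlier plots. The sketch below assumes hypothetical long-format columns `airline`, `hour` and `delay` with toy values. It uses `pivot_table` to produce that shape, and `imshow` as a rough matplotlib-only stand-in for the heat map:

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import pandas as pd

# Toy long-format delays (made-up numbers, hypothetical column names)
long_df = pd.DataFrame({
    'airline': ['AA', 'AA', 'AA', 'DL', 'DL', 'DL'],
    'hour': [6, 6, 18, 6, 18, 18],
    'delay': [2.0, 4.0, 15.0, 1.0, 9.0, 11.0],
})

# Rows = airlines, columns = hours, cells = mean delay
matrix = long_df.pivot_table(index='airline', columns='hour', values='delay', aggfunc='mean')

fig, ax = plt.subplots(figsize=(6, 2))
# origin='lower' puts the first row at the bottom, like ax.invert_yaxis() does for sns.heatmap
im = ax.imshow(matrix.values, cmap='RdBu_r', aspect='auto', origin='lower')
ax.set_xticks(range(matrix.shape[1]))
ax.set_xticklabels(matrix.columns)
ax.set_yticks(range(matrix.shape[0]))
ax.set_yticklabels(matrix.index)
ax.set_xlabel('Hour of day')
fig.colorbar(im, ax=ax)
```

Passing `matrix` straight to `sns.heatmap` adds the row and column labels and the color bar automatically.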

Customizing visualization

Scatter plot of airports with delays

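We want to plot, for the top US airports in 2015, the percentage of delayed flights against the likelihood that weather is causing the delay. Points are colored by region and sized by number of flights.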

The code used to produce the plot is shown below.

# data: DataFrame with columns weather, perc_delay, region, nb_flights (one row per airport)
# Style
sns.set(font_scale=1)
sns.set_style("whitegrid")

# Plot
fig, ax = plt.subplots(figsize=(10, 3))
sns.scatterplot(data=data, x='weather', y='perc_delay', hue='region', size='nb_flights',
                sizes=(50, 200), ax=ax)

# Axes: values are fractions, displayed as whole percentages
ax.xaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))

# Legend: rename the subtitle entries for the hue and size variables by name
# rather than by position, since positions depend on the number of regions
handles, labels = ax.get_legend_handles_labels()
new_names = {'region': 'Region', 'nb_flights': 'Number of flights'}
labels = [new_names.get(label, label) for label in labels]
ax.legend(handles, labels, bbox_to_anchor=(1.04, 0.5), loc="center left",
          borderaxespad=0, frameon=False)
plt.suptitle('Percentage of delayed flights vs likelihood of weather causing the delay',
             x=0.44, y=1.05)
plt.title('Top US airports across 2015', loc='left')
ax.text(x=0.75, y=-0.21, s='Source: publicly available data from DoT',
        fontsize=8, transform=ax.transAxes)
plt.xlabel('Likelihood weather is causing delay')
plt.ylabel('Percentage of delayed flights')
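`PercentFormatter(xmax=1, decimals=0)` treats the plotted values as fractions: a value equal to `xmax=1` is shown as 100%, and labels are rounded to whole percentages. A quick check on an empty axis (the limits are chosen arbitrarily for this sketch):

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick

fig, ax = plt.subplots()
ax.set_ylim(0, 1)

# Values are fractions in [0, 1]; xmax=1 means 1.0 is shown as 100%
fmt = mtick.PercentFormatter(xmax=1, decimals=0)
ax.yaxis.set_major_formatter(fmt)

print(fmt(0.25), fmt(0.333), fmt(1.0))  # 25% 33% 100%
```

If a column already stores percentages (0 to 100), the default `xmax=100` would be the appropriate setting instead.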

Percentage of delayed flights with volume

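We want to plot the number of daily flights in the US in 2015 together with the percentage of delayed flights, each on its own y-axis.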

The code used to produce the plot is shown below.

# data: DataFrame with columns date (datetime), nb_flights, perc_delay
# Style
sns.set(font_scale=1.3)
sns.set_style("whitegrid")

# Plot: number of flights on the left y-axis
fig, ax = plt.subplots(figsize=(15, 6))
sns.lineplot(data=data, x='date', y='nb_flights', color='green', ax=ax)
ax.set_xlabel('Time')
ax.set_ylabel('Number of daily flights')
ax.set_ylim(0, 20000)

# Percentage of delayed flights on a second y-axis sharing the same x-axis
ax2 = ax.twinx()
sns.lineplot(data=data, x='date', y='perc_delay', color='red', ax=ax2)
ax2.set_ylabel('Percentage of delayed flights')
ax2.set_ylim(0, 1)

# Title and source
ax.set_title('Number of flights in the US in 2015', loc='left')
ax.text(x=0.7, y=-0.13, s='Source: publicly available data from DoT',
        fontsize=14, transform=ax.transAxes)
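`twinx()` creates a second `Axes` that shares the x-axis, and each of the two axes keeps its own legend entries. One common way to get a single legend is to merge the handles of both axes. Here is a plain-matplotlib sketch using synthetic series; the labels and values are placeholders:

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Synthetic daily series (placeholder values, not DoT data)
dates = pd.date_range('2015-01-01', periods=365, freq='D')
nb_flights = 15000 + 1000 * np.sin(np.arange(365) / 58)
perc_delay = 0.2 + 0.05 * np.cos(np.arange(365) / 30)

fig, ax = plt.subplots(figsize=(15, 6))
ax.plot(dates, nb_flights, color='green', label='Number of daily flights')
ax.set_ylim(0, 20000)

ax2 = ax.twinx()  # second y-axis on the right, same x-axis
ax2.plot(dates, perc_delay, color='red', label='Percentage of delayed flights')
ax2.set_ylim(0, 1)

# Collect handles from both axes and draw one legend on the first
handles1, labels1 = ax.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
ax.legend(handles1 + handles2, labels1 + labels2, loc='lower right')
```

Coloring each y-axis label to match its line is an alternative that avoids a legend altogether.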

Conclusion

Matplotlib has many features, but is less intuitive than R's ggplot2.