from matplotlib import pyplot as plt
from snippets.plot import label_axes

# Build a figure holding a 2x2 grid of subplots; `axes` is indexed [row, col].
fig, axes = plt.subplots(nrows=2, ncols=2)

# Top row: rely on label_axes' default offset.
label_axes(axes[0, :])
# Bottom row: shift labelling by 2 — presumably so its labels continue after
# the two panels of the top row (confirm against snippets.plot.label_axes).
label_axes(axes[1, :], label_offset=2)