Skip to content
Snippets Groups Projects
plot_utils.py 515 B
Newer Older
#!/usr/bin/env python3

import numpy as np
from matplotlib import cm, colors


def generate_palette_map(df):
    """Create a palette map based on the strains in a dataframe"""
    strain_list = np.unique(df.index.get_level_values("strain"))
    palette_cm = cm.get_cmap("Set1", len(strain_list) + 1)
    palette_rgb = [
        colors.rgb2hex(palette_cm(index / len(strain_list))[:3])
        for index, _ in enumerate(strain_list)
    ]
    palette_map = dict(zip(strain_list, palette_rgb))
    return palette_map