Skip to content
Snippets Groups Projects
Commit 7d4ebb86 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

add MultiGrouper

parent f3ee27ba
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python3
from typing import Union, List, Dict
from abc import ABC, abstractmethod, abstractproperty
from pathlib import Path
from pathos.multiprocessing import Pool
from collections import Counter
import re
import h5py
import numpy as np
import pandas as pd
from p_tqdm import p_map
import matplotlib.pyplot as plt
import seaborn as sns
from agora.io.signal import Signal
......@@ -22,6 +27,7 @@ class Grouper(ABC):
def __init__(self, dir):
path = Path(dir)
self.name = path.name
assert path.exists(), "Dir does not exist"
self.files = list(path.glob("*.h5"))
assert len(self.files), "No valid h5 files in dir"
......@@ -223,3 +229,96 @@ def concat_signal_ind(path, group_names, group, signal, mode="retained", **kwarg
return combined
# except:
# return None
class MultiGrouper:
def __init__(self, source: Union[str, list]):
if isinstance(source, str):
source = Path(source)
self.exp_dirs = list(source.glob("*"))
else:
self.exp_dirs = [Path(x) for x in source]
self.groupers = [NameGrouper(d) for d in self.exp_dirs]
for group in self.groupers:
group.load_signals()
@property
def siglist(self):
for gpr in self.groupers:
print(gpr.siglist_grouped)
@property
def sigtable(self):
"""
Generate a table containing the number of datasets for each signal and experiment
"""
def regex_cleanup(x):
x = re.sub(r"\/extraction\/", "", x)
x = re.sub(r"\/postprocessing\/", "", x)
x = re.sub(r"\/np_max", "", x)
return x
if not hasattr(self, "_sigtable"):
raw_mat = [
[s.siglist for s in gpr.signals.values()] for gpr in self.groupers
]
siglist_grouped = [Counter([x for y in grp for x in y]) for grp in raw_mat]
nexps = len(siglist_grouped)
sigs_idx = list(set([y for x in siglist_grouped for y in x.keys()]))
sigs_idx = [regex_cleanup(x) for x in sigs_idx]
nsigs = len(sigs_idx)
d = {}
sig_matrix = np.zeros((nsigs, nexps))
for i, c in enumerate(siglist_grouped):
for k, v in c.items():
sig_matrix[sigs_idx.index(regex_cleanup(k)), i] = v
sig_matrix[sig_matrix == 0] = np.nan
self._sigtable = pd.DataFrame(
sig_matrix, index=sigs_idx, columns=[x.name for x in mg.exp_dirs]
)
return self._sigtable
def sigtable_plot(self):
ax = sns.heatmap(self.sigtable, cmap="viridis")
ax.set_xticklabels(
ax.get_xticklabels(), rotation=10, ha="right", rotation_mode="anchor"
)
plt.show()
def aggregate_signal(
self,
signals: Union[str, list],
**kwargs,
) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]:
if isinstance(signals, str):
signals = [signals]
sigs = {s: [] for s in signals}
for s in signals:
for grp in self.groupers:
try:
sigset = grp.concat_signal(s)
new_idx = pd.MultiIndex.from_tuples(
[(grp.name, *x) for x in sigset.index],
names=("experiment", *sigset.index.names),
)
sigset.index = new_idx
sigs[s].append(sigset)
except Exception as e:
print("Grouper {} failed: {}".format(grp.name, e))
# raise (e)
concated = {
name: pd.concat(multiexp_sig) for name, multiexp_sig in sigs.items()
}
if len(concated) == 1:
concated = list(concated.values())[0]
return concated
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment