From 0cb3d522570507cff76646d4943e396fe1388f09 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Wed, 16 Jun 2021 20:43:19 +0100
Subject: [PATCH] extend Writer

Former-commit-id: 9a3489325bf848362c0cbd27cc8949cca029376e
---
 core/io/writer.py | 41 ++++++++++++++++++++++++++++++++++++-----
 1 file changed, 36 insertions(+), 5 deletions(-)

diff --git a/core/io/writer.py b/core/io/writer.py
index 480c7e03..240331bb 100644
--- a/core/io/writer.py
+++ b/core/io/writer.py
@@ -1,3 +1,5 @@
+from itertools import accumulate
+
 import h5py
 import pandas as pd
 
@@ -5,15 +7,44 @@ from postprocessor.core.io.base import BridgeH5
 
 
 def Writer(BridgeH5):
-    def __init__(self, filename):
-        self._hdf = h5py.File(filename, "a")
+    """
+    Class in charge of transforming data into compatible formats
+
+    Decoupling interface from implementation!
+
+    :hdfname: Name of file to write into
+    """
+
+    def __init__(self, hdfname):
+        self._hdf = h5py.Hdf(hdfname, "a")
 
     def write(self, address, data):
-        self._file.add_group(address)
+        self._hdf.add_group(address)
         if type(data) is pd.DataFrame:
             self.write_df(address, data)
         elif type(data) is np.array:
             self.write_np(address, data)
 
-    def write_df(self, adress, df):
-        self._file.get(address)[()] = data
+    def write_np(self, address, array):
+        pass
+
+    def write_df(self, df, tps, path):
+        print("writing to ", path)
+        for item in accummulate(path.split("/")[:-2]):
+            if item not in self._hdf:
+                self._hdf.create_group(item)
+        pos_group = f[path.split("/")[1]]
+
+        if path not in pos_group:
+            pos_group.create_dataset(name=path, shape=df.shape, dtype=df.dtypes[0])
+            new_dset = f[path]
+            new_dset[()] = df.values
+            if len(df.index.names) > 1:
+                trap, cell_label = zip(*list(df.index.values))
+                new_dset.attrs["trap"] = trap
+                new_dset.attrs["cell_label"] = cell_label
+                new_dset.attrs["idnames"] = ["trap", "cell_label"]
+            else:
+                new_dset.attrs["trap"] = list(df.index.values)
+                new_dset.attrs["idnames"] = ["trap"]
+        pos_group.attrs["processed_timepoints"] = tps
-- 
GitLab