diff --git a/src/postprocessor/chainer.py b/src/postprocessor/chainer.py index cdf0859b61d7098679c404cc84647d411e31d248..56123a64c2cc8a964f08cc5f393a8a0b2918c05e 100644 --- a/src/postprocessor/chainer.py +++ b/src/postprocessor/chainer.py @@ -33,10 +33,10 @@ class Chainer(Signal): ): data = self.get_raw(dataset, in_minutes=in_minutes) if chain: - data = self.apply_chain(data, **kwargs) + data = self.apply_chain(data, chain, **kwargs) return data - def chain( + def apply_chain( self, input_data: pd.DataFrame, chain: t.Tuple[str, ...], **kwargs ): """Apply a series of processes to a dataset. @@ -61,29 +61,23 @@ class Chainer(Signal): """ - results = copy(input_data) + result = copy(input_data) self._intermediate_steps = [] for process in chain: - params = kwargs.get(process, {}) - process_cls = get_process(process) - result = process_cls.as_function(results, **params) - process_type = process_cls.__module__.split(".")[-2] - if process_type == "reshapers": - self.prepare_step(process_type) - if process == "merger": - merges = process.as_function(results, **params) - results = self.apply_merges(result, merges) - self._intermediate_steps.append(result) - return results - - def prepare_step( - self, data: pd.DataFrame, step: str - ) -> t.Tuple[t.Callable, pd.DataFrame]: - pass + if process == "standard": + result = standard(result, self.lineage()) + else: + params = kwargs.get(process, {}) + process_cls = get_process(process) + result = process_cls.as_function(result, **params) + process_type = process_cls.__module__.split(".")[-2] + if process_type == "reshapers": + if process == "merger": + merges = process.as_function(result, **params) + result = self.apply_merges(result, merges) - def standard_processing(self, url: str): - raw = self.get_raw(url) - st = standard(raw, self.lineage()) + self._intermediate_steps.append(result) + return result def standard(