diff --git a/src/wela/dataloader.py b/src/wela/dataloader.py index 85a6679c4ad1d3e0a64ff42cf4f31f65e301540b..b19d72f7846468882e028225a9ad2edecf08c1a8 100644 --- a/src/wela/dataloader.py +++ b/src/wela/dataloader.py @@ -598,6 +598,7 @@ class dataloader: duration_threshold=None, tmin=None, tmax=None, + group=None, ): """ Find a sub data frame of dataloader's main data frame. @@ -615,8 +616,13 @@ class dataloader: Only include data for times greater than tmin tmax: float (optional) Only include data for times less than tmax + group: str (optional) + Group to specialise to. """ - sdf = self.df + if group is None: + sdf = self.df + else: + sdf = self.df[self.df.group == group] selected_ids = [] # drop signals that are all NaN if dropna and signal: