Newer
Older
def load_dataset(filename, label):
    """Build a labelled tf.data pipeline from a text file.

    Reads ``filename`` line by line, drops comment lines, applies the
    module-level ``preproc`` map, shuffles and batches (using the
    module-level ``buffer_size`` / ``batch_size``), and pairs every
    example with the constant ``label``.

    Args:
        filename: Path (or list of paths) accepted by TextLineDataset.
        label: Scalar class label attached to every example.

    Returns:
        A tf.data.Dataset yielding ``(batch, labels)`` tuples, where
        ``labels`` is a vector repeating ``label`` once per batch row.
    """
    dataset = tf.data.TextLineDataset(filename, num_parallel_reads=5)
    # filter() returns a NEW dataset; the original discarded the result,
    # so comment lines were never actually removed.
    dataset = dataset.filter(filter_function)
    # Apply eventual preprocessing steps
    dataset = dataset.map(preproc)
    dataset = dataset.shuffle(buffer_size)
    dataset = dataset.batch(batch_size)
    # Attach the label per element. The original tried
    # tf.fill(cardinality, label) + tf.Dataset.zip, but cardinality is
    # unknown after filter(), tf.fill yields a tensor (not a dataset),
    # and tf.Dataset.zip does not exist (it is tf.data.Dataset.zip).
    # Mapping avoids all three problems.
    dataset = dataset.map(lambda x: (x, tf.fill([tf.shape(x)[0]], label)))
    return dataset
def combine_datasets(datasets):
    """Concatenate a sequence of datasets into a single dataset.

    Args:
        datasets: Non-empty sequence of objects supporting
            ``.concatenate(other)`` (e.g. tf.data.Dataset).

    Returns:
        One dataset that yields the elements of each input in order.

    Raises:
        IndexError: If ``datasets`` is empty.
    """
    # The original referenced an undefined name `dataset`, handled
    # exactly three inputs, and never returned the result.
    combined = datasets[0]
    for ds in datasets[1:]:
        combined = combined.concatenate(ds)
    return combined
def split_train_test_validation(dataset, train_size, test_size, val_size):
    """Split a dataset into disjoint train / test / validation parts.

    Args:
        dataset: Object supporting ``.take(n)`` and ``.skip(n)``
            (e.g. tf.data.Dataset).
        train_size: Number of leading elements for the training split.
        test_size: Number of elements after train for the test split.
        val_size: Number of elements after test for the validation split.

    Returns:
        Tuple ``(train_dataset, test_dataset, val_dataset)``.
    """
    # The original took all three splits from the START of the dataset,
    # so test and validation overlapped train — and returned nothing.
    train_dataset = dataset.take(train_size)
    test_dataset = dataset.skip(train_size).take(test_size)
    val_dataset = dataset.skip(train_size + test_size).take(val_size)
    return train_dataset, test_dataset, val_dataset
def filter_function(line):
    """Predicate for Dataset.filter: keep lines that contain no '#'.

    Args:
        line: Scalar string tensor (one line of the input file).

    Returns:
        Boolean scalar tensor — True when the line has no '#' and
        should be kept.
    """
    # The original called tf.not_equal with a single argument, which is
    # a TypeError. The intent — drop comment lines — is the logical
    # negation of the contains() test.
    return tf.logical_not(tf.strings.contains(line, '#'))
def set_shape(x):
    """Assert the static shape of tensor ``x`` and return it.

    Uses the module-level ``Nbeads`` and ``dimensions`` globals —
    presumably the per-example (beads, coords) layout; confirm with
    the pipeline that calls this.
    """
    target_shape = (Nbeads, dimensions)
    x.set_shape(target_shape)
    return x
def pad_sequences(x, maxlen=1000, value=-100):
    """Pad (or truncate) sequences in ``x`` to a fixed length.

    Thin wrapper over ``tf.keras.utils.pad_sequences``; the previously
    hard-coded length and pad value are now defaulted keyword
    parameters, so existing callers are unaffected.

    Args:
        x: Sequence of sequences accepted by keras ``pad_sequences``.
        maxlen: Target length of every sequence (default 1000).
        value: Fill value used for padding (default -100, presumably a
            sentinel the model masks out — confirm against training code).

    Returns:
        The padded array returned by keras ``pad_sequences``.
    """
    return tf.keras.utils.pad_sequences(x, maxlen=maxlen, value=value)