Skip to content
Snippets Groups Projects
Commit 375450da authored by fconforto's avatar fconforto
Browse files

Fixed loading script

parent 6cb9f929
Branches main
No related tags found
No related merge requests found
import tensorflow as tf
def load_dataset(filename, label):
# Loading the dataset file
dataset = tf.data.experimental.make_csv_dataset(filename, 200,
select_columns=[2],column_defaults = [ tf.float32, ],num_epochs=100000,
header=True, field_delim=" ")
dataset = tf.data.experimental.CsvDataset(filename, [tf.float32, ], select_cols=[2], header=True, field_delim=" ")
# Reshape the incoming data
dataset = dataset.batch(200)
# Create labelled vector
dataset = dataset.map(lambda x: (x.values(),label))
dataset = dataset.map(lambda x: (tf.reshape(x,(200,1)),label))
return dataset
......
......@@ -29,14 +29,14 @@ def main(argv):
in_layer = (Nbeads,1)
datasets = []
for knot in knots:
datasets.append(load_dataset("/storage/cmstore02/groups/TAPLab/knots_database/" + str(knot) + "/N200/lp10/SIGWRITHE/3DSignedWrithe_" + knot + ".dat.lp10.dat.nos", knot))
for i, knot in enumerate(knots):
datasets.append(load_dataset("/storage/cmstore02/groups/TAPLab/knots_database/" + str(knot) + "/N200/lp10/SIGWRITHE/3DSignedWrithe_" + knot + ".dat.lp10.dat.nos", i))
dataset = combine_datasets(datasets)
dataset = tf.data.experimental.sample_from_datasets(datasets)
for element in dataset.as_numpy_iterator():
print(element)
break
dataset = dataset.shuffle(buffer_size=1)
dataset = dataset.repeat()
if net == "FFNN":
model = setup_NN(
......@@ -59,7 +59,7 @@ def main(argv):
# Early Stopping Callback
es = tf.keras.callbacks.EarlyStopping(
monitor="val_loss",
monitor="loss",
mode="min",
verbose=1,
patience=10,
......@@ -84,20 +84,25 @@ def main(argv):
mc = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath,
save_weights_only=False,
monitor="val_loss",
monitor="loss",
mode="min",
save_best_only=True,
)
cb_list = [es, mc]
bs = 256
dataset = dataset.batch(bs)
for i,element in enumerate(dataset.as_numpy_iterator()):
print(element)
if i>=1:
break
model.fit(
dataset,
epochs=1,
batch_size=bs,
verbose=0,
# validation_split=1 / 5,
steps_per_epoch=300,
epochs=20,
verbose=1,
callbacks=cb_list,
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment