Skip to content
Snippets Groups Projects
Commit 3891fc94 authored by Joseph Sleiman's avatar Joseph Sleiman
Browse files

Merge remote-tracking branch 'origin/master'

parents 9dd1f995 6cb9f929
Branches main
No related tags found
No related merge requests found
import tensorflow as tf
def load_dataset(filename, preproc, label):
def load_dataset(filename, label):
# Loading the dataset file
dataset = tf.data.TextLineDataset(filename, num_parallel_reads=5)
......@@ -26,12 +26,13 @@ def load_dataset(filename, preproc, label):
def combine_datasets(datasets):
# Combining the datasets
combined_dataset = datasets[0].concatenate(datasets[1])
c_dataset = datasets[0].concatenate(datasets[1])
for dataset in datasets[2:]:
combined_dataset = combined_dataset.concatenate(dataset)
c_dataset = c_dataset.concatenate(dataset)
return c_dataset
def split_train_test_validation(dataset, train_size, test_size, val_size):
train_dataset = dataset.take(train_size)
test_dataset = dataset.take(test_size)
......
......@@ -18,21 +18,26 @@ def main(argv):
dtype = args.datatype
adj = args.adjacent
norm = args.normalised
net = args.normalised
net = args.network
epochs = args.epochs
knots = set_constants(prob)
Nbeads = 200
if dtype == "XYZ":
in_layer = (Nbeads, 3)
else:
in_layer = (Nbeads,1)
datasets = []
for knot in knots:
datasets.append(load_dataset(knot))
datasets.append(load_dataset("/storage/cmstore02/groups/TAPLab/knots_database/" + str(knot) + "/N200/lp10/SIGWRITHE/3DSignedWrithe_" + knot + ".dat.lp10.dat.nos", knot))
dataset = combine_datasets(datasets)
Nbeads = 200
if dtype == "XYZ":
in_layer = (Nbeads, 3)
else:
in_layer = (Nbeads,)
for element in dataset.as_numpy_iterator():
print(element)
break
if net == "FFNN":
model = setup_NN(
[320, 320, 320, 320],
......@@ -87,14 +92,14 @@ def main(argv):
cb_list = [es, mc]
bs = 256
# model.fit(
# dataset,
# epochs=epochs,
# batch_size=bs,
# verbose=1,
# validation_split=1 / 5,
# callbacks=cb_list,
# )
model.fit(
dataset,
epochs=1,
batch_size=bs,
verbose=0,
# validation_split=1 / 5,
callbacks=cb_list,
)
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment