Skip to content
Snippets Groups Projects
Commit 66f45f2f authored by fconforto's avatar fconforto
Browse files

Added normalisation layer and test option

parent 0fcc5f6b
Branches main
No related tags found
No related merge requests found
......@@ -80,6 +80,14 @@ def getParams():
"-e", "--epochs", type=int, default=1000, help="Set the number of training epochs"
)
par.add_argument(
"-m",
"--mode",
type=str,
default="train",
help="mode: train or test",
)
args = par.parse_args()
return args
......@@ -21,6 +21,7 @@ def main(argv):
net = args.network
epochs = args.epochs
knots = set_constants(prob)
mode = args.mode
Nbeads = int(args.nbeads)
if dtype == "XYZ":
......@@ -39,7 +40,7 @@ def main(argv):
dataset = dataset.shuffle(buffer_size=1, seed=42)
ninputs = len(knots)*100000
train_dataset, test_dataset, val_dataset = split_train_test_validation(dataset,int(ninputs*(0.7)),int(ninputs*(0.1)),int(ninputs*(0.2)))
train_dataset, test_dataset, val_dataset = split_train_test_validation(dataset,int(ninputs*(0.9)),int(ninputs*(0.075)),int(ninputs*(0.025)))
if net == "FFNN":
model = setup_NN(
......@@ -62,6 +63,17 @@ def main(argv):
else:
raise NameError
if mode=="train":
train(model, train_dataset, val_dataset, args, 8500)
elif mode=="test":
test(test_dataset, args)
def train(model,train_dataset,val_dataset,args, bs):
train_dataset = train_dataset.repeat()
train_dataset = train_dataset.batch(bs)
val_dataset = val_dataset.batch(bs)
# Early Stopping Callback
es = tf.keras.callbacks.EarlyStopping(
monitor="val_loss",
......@@ -76,18 +88,19 @@ def main(argv):
checkpoint_filepath = (
"NN_model_best"
+ "_"
+ str(dtype)
+ str(args.datatype)
+ "_"
+ str(prob)
+ str(args.problem)
+ "_Adj_"
+ str(adj)
+ str(args.adjacent)
+ "_Norm_"
+ str(norm)
+ str(args.normalised)
+ "_Net_"
+ str(net)
+ str(args.network)
+ "Nbeads"
+ str(Nbeads)
+ str(args.nbeads)
)
mc = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath,
save_weights_only=False,
......@@ -97,20 +110,40 @@ def main(argv):
)
cb_list = [es, mc]
bs = 8500
train_dataset = train_dataset.repeat()
train_dataset = train_dataset.batch(bs)
val_dataset = val_dataset.batch(bs)
model.fit(
train_dataset,
steps_per_epoch=50,
epochs=epochs,
epochs=args.epochs,
verbose=1,
callbacks=cb_list,
validation_data = val_dataset
)
def test(test_dataset,args):
checkpoint_filepath = (
"NN_model_best"
+ "_"
+ str(args.datatype)
+ "_"
+ str(args.problem)
+ "_Adj_"
+ str(args.adjacent)
+ "_Norm_"
+ str(args.normalised)
+ "_Net_"
+ str(args.network)
+ "Nbeads"
+ str(args.nbeads)
)
model = tf.saved_model.load(checkpoint_filepath)
test_labels = test_dataset.map(lambda x,y: y)
test_dataset = test_dataset.map(lambda x,y: x)
predictions = model.predict(test_dataset)
print(tf.math.confusion_matrix(test_labels, predictions))
if __name__ == "__main__":
main(sys.argv[1:])
......@@ -11,6 +11,9 @@ def setup_RNN(RNN_hidden_top, input_shape, output_shape, hidden_activation, opt,
return_sequences=True,
recurrent_dropout=0))
if norm:
model.add(tf.keras.layers.BatchNormalization())
# add bidirectional LSTM layer
model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(RNN_hidden_top[1], activation=hidden_activation, return_sequences=True, recurrent_dropout = 0)))
......@@ -53,6 +56,9 @@ def setup_NN(NN_hidden_top, input_shape, output_shape, hidden_activation, opt, n
# kernel_initializer = tf.keras.initializers.RandomUniform(seed=None),
activation=hidden_activation))
if norm:
model.add(tf.keras.layers.BatchNormalization())
# add hidden layers to NN
for i in range(len(NN_hidden_top)-1):
model.add(tf.keras.layers.Dense(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment