Skip to content
Snippets Groups Projects
Commit ea27c2ad authored by fconforto's avatar fconforto
Browse files

Added shap

parent 6636e643
Branches main
No related tags found
No related merge requests found
......@@ -9,6 +9,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"
import tensorflow as tf
import keras_tuner as kt
from keras import backend as K
import shap
import numpy as np
import matplotlib.pyplot as plt
......@@ -154,7 +155,7 @@ def plot_importance(test_dataset):
label_dataset = test_dataset.map(lambda x, y: y)
test_dataset = test_dataset.map(lambda x, y: x)
el_num = (np.random.rand(4) * len_db * 0.075 * len(knots)).astype(int)
el_num = (np.random.rand(50) * len_db * 0.075 * len(knots)).astype(int)
input_sequences = []
input_labels = []
......@@ -163,7 +164,17 @@ def plot_importance(test_dataset):
input_sequences.append(x_element.numpy().reshape(1,-1))
input_labels.append(y_element)
input_sequences = np.array(input_sequences)
explainer = shap.DeepExplainer(model, input_sequences.transpose(0,2,1))
for j, (input_sequence, input_label) in enumerate(zip(input_sequences, input_labels)):
plt.plot(explainer.shap_values(input_sequence.transpose(1,0)[None,:,:])[input_label].flatten())
plt.savefig(os.path.join(checkpoint_filepath,"layer_analysis",f"activation_gradient_label_{input_label}_test_{el_num[j]}.pdf"))
plt.cla()
for i in range(len(model.layers)):
intermediate_model=tf.keras.models.Model(inputs=model.input,outputs=model.layers[i].output) #Intermediate model between Input Layer and Output Layer which we are concerned about
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment