aboutsummaryrefslogtreecommitdiff
path: root/training/dump_rnn.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/dump_rnn.py')
-rwxr-xr-xtraining/dump_rnn.py20
1 files changed, 15 insertions, 5 deletions
diff --git a/training/dump_rnn.py b/training/dump_rnn.py
index 9652c81..964a5fa 100755
--- a/training/dump_rnn.py
+++ b/training/dump_rnn.py
@@ -30,20 +30,24 @@ def printVector(f, vector, name):
f.write('\n};\n\n')
return;
-def printLayer(f, layer):
+def printLayer(f, hf, layer):
weights = layer.get_weights()
printVector(f, weights[0], layer.name + '_weights')
if len(weights) > 2:
printVector(f, weights[1], layer.name + '_recurrent_weights')
printVector(f, weights[-1], layer.name + '_bias')
name = layer.name
- activation = re.search('function (.*) at', str(layer.activation)).group(1)
+ activation = re.search('function (.*) at', str(layer.activation)).group(1).upper()
if len(weights) > 2:
- f.write('const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, activation_{}\n}};\n\n'
+ f.write('const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n'
.format(name, name, name, name, weights[0].shape[0], weights[0].shape[1]/3, activation))
+ hf.write('#define {}_SIZE {}\n'.format(name.upper(), weights[0].shape[1]/3))
+ hf.write('extern const GRULayer {};\n\n'.format(name));
else:
- f.write('const DenseLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, activation_{}\n}};\n\n'
+ f.write('const DenseLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n'
.format(name, name, name, weights[0].shape[0], weights[0].shape[1], activation))
+ hf.write('#define {}_SIZE {}\n'.format(name.upper(), weights[0].shape[1]))
+ hf.write('extern const DenseLayer {};\n\n'.format(name));
def mean_squared_sqrt_error(y_true, y_pred):
@@ -55,13 +59,19 @@ model = load_model(sys.argv[1], custom_objects={'msse': mean_squared_sqrt_error,
weights = model.get_weights()
f = open(sys.argv[2], 'w')
+hf = open(sys.argv[3], 'w')
f.write('/*This file is automatically generated from a Keras model*/\n\n')
f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "rnn.h"\n\n')
+hf.write('/*This file is automatically generated from a Keras model*/\n\n')
+hf.write('#ifndef RNN_DATA_H\n#define RNN_DATA_H\n\n#include "rnn.h"\n\n')
+
for i, layer in enumerate(model.layers):
if len(layer.get_weights()) > 0:
- printLayer(f, layer)
+ printLayer(f, hf, layer)
+hf.write('\n\n#endif\n')
f.close()
+hf.close()