diff options
Diffstat (limited to 'training/dump_rnn.py')
-rwxr-xr-x | training/dump_rnn.py | 20 |
1 files changed, 15 insertions, 5 deletions
diff --git a/training/dump_rnn.py b/training/dump_rnn.py index 9652c81..964a5fa 100755 --- a/training/dump_rnn.py +++ b/training/dump_rnn.py @@ -30,20 +30,24 @@ def printVector(f, vector, name): f.write('\n};\n\n') return; -def printLayer(f, layer): +def printLayer(f, hf, layer): weights = layer.get_weights() printVector(f, weights[0], layer.name + '_weights') if len(weights) > 2: printVector(f, weights[1], layer.name + '_recurrent_weights') printVector(f, weights[-1], layer.name + '_bias') name = layer.name - activation = re.search('function (.*) at', str(layer.activation)).group(1) + activation = re.search('function (.*) at', str(layer.activation)).group(1).upper() if len(weights) > 2: - f.write('const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, activation_{}\n}};\n\n' + f.write('const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n' .format(name, name, name, name, weights[0].shape[0], weights[0].shape[1]/3, activation)) + hf.write('#define {}_SIZE {}\n'.format(name.upper(), weights[0].shape[1]/3)) + hf.write('extern const GRULayer {};\n\n'.format(name)); else: - f.write('const DenseLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, activation_{}\n}};\n\n' + f.write('const DenseLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n' .format(name, name, name, weights[0].shape[0], weights[0].shape[1], activation)) + hf.write('#define {}_SIZE {}\n'.format(name.upper(), weights[0].shape[1])) + hf.write('extern const DenseLayer {};\n\n'.format(name)); def mean_squared_sqrt_error(y_true, y_pred): @@ -55,13 +59,19 @@ model = load_model(sys.argv[1], custom_objects={'msse': mean_squared_sqrt_error, weights = model.get_weights() f = open(sys.argv[2], 'w') +hf = open(sys.argv[3], 'w') f.write('/*This file is automatically generated from a Keras model*/\n\n') f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "rnn.h"\n\n') +hf.write('/*This file is automatically generated from a Keras model*/\n\n') +hf.write('#ifndef RNN_DATA_H\n#define RNN_DATA_H\n\n#include "rnn.h"\n\n') + for i, layer in enumerate(model.layers): if len(layer.get_weights()) > 0: - printLayer(f, layer) + printLayer(f, hf, layer) +hf.write('\n\n#endif\n') f.close() +hf.close() |