46 lines
1.5 KiB
Python
Executable File
46 lines
1.5 KiB
Python
Executable File
#!/usr/bin/env python
|
|
"""Extracts trainable parameters from Caffe models and stores them in numpy arrays.

Usage

    python caffe_data_extractor -m path_to_caffe_model_file -n path_to_caffe_netlist

Saves each variable to a {variable_name}.npy binary file, where variable_name
is the layer name followed by "_w" (weights) or "_b" (bias); any path
separators in the name are replaced with "_".

Tested with Caffe 1.0 on Python 2.7
"""
|
|
import argparse
|
|
import caffe
|
|
import os
|
|
import numpy as np
|
|
|
|
|
|
# Output-file suffixes for the first two parameter blobs of each layer:
# blob 0 is saved as "<layer>_w" (weights), blob 1 as "<layer>_b" (bias).
# Any further blobs a layer has are reported and skipped.
BLOB_SUFFIXES = ("_w", "_b")


def sanitize_name(name):
    """Return *name* with every filesystem path separator replaced by '_'.

    Caffe layer names may contain separators such as '/', which np.save
    would otherwise treat as a directory component. Both ``os.sep`` and
    ``os.altsep`` are replaced so that '/' is also handled on Windows, where
    ``os.sep`` is a backslash and '/' is only the alternate separator.
    """
    for sep in (os.sep, os.altsep):
        if sep:  # os.altsep is None on POSIX systems
            name = name.replace(sep, '_')
    return name


def parse_args(argv=None):
    """Parse the command line (``sys.argv[1:]`` when *argv* is None)."""
    # The text must go to `description`; as the first positional argument it
    # would be taken as `prog` and replace the program name in the usage line.
    parser = argparse.ArgumentParser(description='Extract Caffe net parameters')
    parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Caffe model file')
    parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Caffe netlist')
    parser.add_argument('-o', dest='outDir', type=str, default='.',
                        help='Directory to write the .npy files to (default: current directory)')
    return parser.parse_args(argv)


def dump_params(params, out_dir='.'):
    """Save the weight and bias blobs of every layer as .npy files.

    Args:
        params: mapping of layer name -> sequence of blobs, each exposing a
            ``.data`` array (as ``caffe.Net.params`` does).
        out_dir: directory the files are written to; created if missing.
    """
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    # .items() instead of the Python-2-only .iteritems() so the script also
    # runs under Python 3.
    for name, blobs in params.items():
        print('Name: {0}, Blobs: {1}'.format(name, len(blobs)))
        for suffix, blob in zip(BLOB_SUFFIXES, blobs):
            outname = name + suffix
            varname = sanitize_name(outname)
            if varname != outname:
                print("Renaming variable {0} to {1}".format(outname, varname))
            print("Saving variable {0} with shape {1} ...".format(varname, blob.data.shape))
            # Dump as binary; np.save appends the ".npy" extension.
            np.save(os.path.join(out_dir, varname), blob.data)
        extra = len(blobs) - len(BLOB_SUFFIXES)
        if extra > 0:
            # Previously dropped silently; make the omission visible.
            print("Skipping {0} extra blob(s) of {1}".format(extra, name))


def main(argv=None):
    """Load the Caffe net named on the command line and dump its parameters."""
    args = parse_args(argv)
    # caffe.TEST is the phase constant that the bare literal 1 stood for.
    net = caffe.Net(args.netFile, caffe.TEST, weights=args.modelFile)
    dump_params(net.params, args.outDir)


if __name__ == "__main__":
    main()
|