60 lines
2.3 KiB
Python
60 lines
2.3 KiB
Python
#!/usr/bin/env python
|
|
"""Extracts mnist image data from the Caffe data files and stores them in numpy arrays
|
|
Usage
|
|
python caffe_mnist_image_extractor.py -d path_to_caffe_data_directory -o desired_output_path
|
|
|
|
Saves the first 10 images extracted as input10.npy, the first 100 images as input100.npy, and the
|
|
corresponding labels to labels100.npy (a plain-text file with one label per line, despite the .npy extension).
|
|
|
|
Tested with Caffe 1.0 on Python 2.7
|
|
"""
|
|
import argparse
|
|
import os
|
|
import struct
|
|
import numpy as np
|
|
from array import array
|
|
|
|
|
|
# Number of images/labels exported to input100.npy and labels100.npy.
_NUM_IMAGES = 100
# Size of the leading subset that is additionally exported to input10.npy.
_NUM_IMAGES_SMALL = 10
# Magic numbers that open an IDX3 image file and an IDX1 label file.
_IMAGES_MAGIC = 2051
_LABELS_MAGIC = 2049


def _read_images(path, count):
    """Read the first `count` images from an IDX3 (unsigned byte) image file.

    Args:
        path: Path to the IDX3 image file.
        count: Number of leading images to read.

    Returns:
        A float32 numpy array of shape (count, 1, rows, cols), where rows and
        cols come from the file header and each pixel is divided by 256.0.

    Raises:
        ValueError: If the magic number is wrong, the header declares fewer
            than `count` images, or the pixel data is truncated.
        struct.error: If the file is too short to hold the 16-byte header.
    """
    with open(path, 'rb') as images_file:
        magic, total, rows, cols = struct.unpack('>IIII', images_file.read(16))
        if magic != _IMAGES_MAGIC:
            raise ValueError('%s is not an IDX3 image file (magic number %d)' % (path, magic))
        if total < count:
            raise ValueError('%s holds %d images, %d are required' % (path, total, count))
        # Only the leading images are needed, so do not load the whole file.
        raw = images_file.read(count * rows * cols)

    if len(raw) != count * rows * cols:
        raise ValueError('%s is truncated' % path)

    pixels = np.frombuffer(raw, dtype=np.uint8).reshape((count, 1, rows, cols))
    # Same scaling as before: bytes 0..255 divided by 256.0 (exact in float32).
    return pixels.astype(np.float32) / 256.0


def _read_labels(path, count):
    """Read the first `count` labels from an IDX1 label file.

    Args:
        path: Path to the IDX1 label file.
        count: Number of leading labels to read.

    Returns:
        A list of `count` Python ints (each byte interpreted as signed).

    Raises:
        ValueError: If the magic number is wrong, the header declares fewer
            than `count` labels, or the label data is truncated.
        struct.error: If the file is too short to hold the 8-byte header.
    """
    with open(path, 'rb') as labels_file:
        magic, total = struct.unpack('>II', labels_file.read(8))
        if magic != _LABELS_MAGIC:
            raise ValueError('%s is not an IDX1 label file (magic number %d)' % (path, magic))
        if total < count:
            raise ValueError('%s holds %d labels, %d are required' % (path, total, count))
        raw = labels_file.read(count)

    if len(raw) != count:
        raise ValueError('%s is truncated' % path)

    return array('b', raw).tolist()


def main(argv=None):
    """Extract the first MNIST test images and labels from a Caffe data directory.

    Reads mnist/t10k-images-idx3-ubyte and mnist/t10k-labels-idx1-ubyte below
    the data directory and writes three files into the output directory:
    input10.npy (first 10 images), input100.npy (first 100 images) and
    labels100.npy (first 100 labels as plain text, one per line).

    Args:
        argv: Command-line arguments without the program name. None means
            argparse falls back to sys.argv[1:].
    """
    # Parse arguments
    parser = argparse.ArgumentParser('Extract Caffe mnist image data')
    parser.add_argument('-d', dest='dataDir', type=str, required=True, help='Path to Caffe data directory')
    parser.add_argument('-o', dest='outDir', type=str, default='.', help='Output directory (default = current directory)')
    args = parser.parse_args(argv)

    images_filename = os.path.join(args.dataDir, 'mnist/t10k-images-idx3-ubyte')
    labels_filename = os.path.join(args.dataDir, 'mnist/t10k-labels-idx1-ubyte')

    images = _read_images(images_filename, _NUM_IMAGES)
    labels = _read_labels(labels_filename, _NUM_IMAGES)

    input10_path = os.path.join(args.outDir, 'input10.npy')
    input100_path = os.path.join(args.outDir, 'input100.npy')
    # NOTE: despite the .npy extension this file is plain text; the name is
    # kept unchanged so existing consumers of the output still find it.
    labels100_path = os.path.join(args.outDir, 'labels100.npy')

    np.save(input10_path, images[:_NUM_IMAGES_SMALL])
    print("Wrote " + input10_path)

    with open(labels100_path, 'w') as labels_output:
        labels_output.writelines(str(label) + '\n' for label in labels)
    print("Wrote " + labels100_path)

    np.save(input100_path, images)
    print("Wrote " + input100_path)


if __name__ == "__main__":
    main()
|