225 lines
7.4 KiB
Python
225 lines
7.4 KiB
Python
"""Setup TensorFlow as external dependency"""
|
|
|
|
# Names of the environment variables this repository rule reads.

# Directory holding the TensorFlow C++ headers.
_TF_HEADER_DIR = "TF_HEADER_DIR"

# Directory that contains the TensorFlow shared library.
_TF_SHARED_LIBRARY_DIR = "TF_SHARED_LIBRARY_DIR"

# File name of the shared library inside _TF_SHARED_LIBRARY_DIR.
_TF_SHARED_LIBRARY_NAME = "TF_SHARED_LIBRARY_NAME"
|
|
|
|
def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    """Expands the template //third_party/tensorflow:<tpl>.tpl into the repository.

    Args:
      repository_ctx: the repository_ctx object.
      tpl: base name of the template; the file read is "<tpl>.tpl".
      substitutions: dict of placeholder -> replacement text passed to
        repository_ctx.template.
      out: output path to write; when falsy, `tpl` is used instead.
    """
    template_label = Label("//third_party/tensorflow:%s.tpl" % tpl)
    repository_ctx.template(out or tpl, template_label, substitutions)
|
|
|
|
def _fail(msg):
    """Aborts the repository rule with a highlighted configuration error.

    Args:
      msg: human-readable description of what went wrong.
    """
    red = "\033[0;31m"
    no_color = "\033[0m"

    # Only the prefix is colored (ANSI red) so it stands out in Bazel output.
    # The label names TensorFlow because this file configures TensorFlow;
    # the previous "Python Configuration Error" wording was misleading.
    fail("%sTensorFlow Configuration Error:%s %s\n" % (red, no_color, msg))
|
|
|
|
def _is_windows(repository_ctx):
    """Reports whether the host operating system is Windows.

    Args:
      repository_ctx: the repository_ctx object.

    Returns:
      True when "windows" occurs (case-insensitively) in repository_ctx.os.name,
      False otherwise.
    """
    return "windows" in repository_ctx.os.name.lower()
|
|
|
|
def _execute(
        repository_ctx,
        cmdline,
        error_msg = None,
        error_details = None,
        empty_stdout_fine = False):
    """Runs a command via repository_ctx.execute and aborts on failure.

    The command counts as failed when it writes anything to stderr, or when its
    stdout is empty and empty_stdout_fine is False. The process exit code is
    not inspected.

    Args:
      repository_ctx: the repository_ctx object.
      cmdline: list of strings, the command to execute.
      error_msg: string, a summary of the error if the command fails.
      error_details: string, details about the error or steps to fix it.
      empty_stdout_fine: bool, if True an empty stdout is accepted; otherwise
        it is treated as an error.

    Returns:
      The result of repository_ctx.execute(cmdline).
    """
    exec_result = repository_ctx.execute(cmdline)

    stdout_missing = not empty_stdout_fine and not exec_result.stdout
    if exec_result.stderr or stdout_missing:
        summary = "Repository command failed"
        if error_msg:
            summary = error_msg.strip()
        details = ""
        if error_details:
            details = error_details
        _fail("\n".join([summary, exec_result.stderr.strip(), details]))

    return exec_result
|
|
|
|
def _read_dir(repository_ctx, src_dir):
    """Lists every regular file below src_dir, one full path per line.

    Subdirectories are traversed and symlinks are followed (`find -follow` on
    POSIX hosts, `dir /s` on Windows). On Windows, backslashes in the output are
    rewritten to forward slashes.

    Args:
      repository_ctx: the repository_ctx object.
      src_dir: directory to list files from.

    Returns:
      A string with the full path of each file, separated by line breaks.
    """
    if not _is_windows(repository_ctx):
        listing = _execute(
            repository_ctx,
            ["find", src_dir, "-follow", "-type", "f"],
            empty_stdout_fine = True,
        )
        return listing.stdout

    # cmd.exe's `dir` expects backslash-separated paths.
    windows_dir = src_dir.replace("/", "\\")
    listing = _execute(
        repository_ctx,
        ["cmd.exe", "/c", "dir", windows_dir, "/b", "/s", "/a-d"],
        empty_stdout_fine = True,
    )

    # The paths are later used in genrule.outs, which requires forward slashes.
    return listing.stdout.replace("\\", "/")
|
|
|
|
def _genrule(genrule_name, command, outs):
    """Renders the source text of a genrule target.

    Args:
      genrule_name: unique name for the genrule target.
      command: shell command placed verbatim inside the triple-quoted `cmd`.
      outs: pre-formatted text of the `outs` list entries.

    Returns:
      A string containing the genrule(...) call.
    """
    pieces = [
        "genrule(\n",
        ' name = "',
        genrule_name,
        '",\n',
        " outs = [\n",
        outs,
        "\n ],\n",
        ' cmd = """\n',
        command,
        '\n """,\n',
        ")\n",
    ]
    return "".join(pieces)
|
|
|
|
def _norm_path(path):
    """Returns path with '/' separators and one trailing '/' removed.

    Backslashes are converted to forward slashes, then a single trailing slash
    (if present) is dropped. An empty path is returned unchanged instead of
    failing on an out-of-range index.

    Args:
      path: a file-system path string.

    Returns:
      The normalized path string.
    """
    path = path.replace("\\", "/")

    # endswith() is safe on "", unlike indexing path[-1].
    if path.endswith("/"):
        path = path[:-1]
    return path
|
|
|
|
def _symlink_genrule_for_dir(
        repository_ctx,
        src_dir,
        dest_dir,
        genrule_name,
        src_files = [],
        dest_files = [],
        tf_pip_dir_rename_pair = []):
    """Returns a genrule that copies a set of files into the output tree.

    If src_dir is not None, the files are discovered by listing src_dir
    recursively. Otherwise, the caller supplies src_files and the matching
    dest_files. Every file is copied with `cp -f`; no symlinks are created.

    Args:
      repository_ctx: the repository_ctx object.
      src_dir: source directory, or None to use src_files/dest_files.
      dest_dir: directory (relative to the genrule output) to copy into.
      genrule_name: genrule name.
      src_files: list of source files, used when src_dir is None.
      dest_files: list of corresponding destination files.
      tf_pip_dir_rename_pair: empty, or [old, new]. Each occurrence of `old`
        in the destination paths is replaced by `new`. For example, in the TF
        pip package the sources live under "tensorflow_core", which can be
        renamed to "tensorflow" to match the header includes.

    Returns:
      A string with the genrule target that performs the copies.
    """
    # tf_pip_dir_rename_pair must be either empty or exactly [old, new].
    rename_len = len(tf_pip_dir_rename_pair)
    if rename_len != 0 and rename_len != 2:
        _fail("The size of argument tf_pip_dir_rename_pair should be either 0 or 2, but %d is given." % rename_len)

    if src_dir != None:
        src_dir = _norm_path(src_dir)
        dest_dir = _norm_path(dest_dir)
        src_files = sorted(_read_dir(repository_ctx, src_dir).splitlines())

        # Strip only the leading src_dir from each path. A global string
        # replace would also mangle later components that contain src_dir.
        rel_files = [
            f[len(src_dir):] if f.startswith(src_dir) else f
            for f in src_files
        ]
        if rename_len:
            rel_files = [
                f.replace(tf_pip_dir_rename_pair[0], tf_pip_dir_rename_pair[1])
                for f in rel_files
            ]
        dest_files = rel_files

    # With exactly one output, $(@D) already expands to that file's directory,
    # so dest_dir must not be prefixed again in the copy target.
    single_output = len(dest_files) == 1

    command = []
    outs = []
    for i in range(len(dest_files)):
        if dest_files[i] == "":
            continue
        if single_output:
            dest = "$(@D)/" + dest_files[i]
        else:
            dest = "$(@D)/" + dest_dir + dest_files[i]

        # Copy (rather than symlink) to keep the setup sandboxable.
        command.append('cp -f "%s" "%s"' % (src_files[i], dest))
        outs.append(' "' + dest_dir + dest_files[i] + '",')

    return _genrule(
        genrule_name,
        " && ".join(command),
        "\n".join(outs),
    )
|
|
|
|
def _tf_required_env(repository_ctx, name):
    """Returns environment variable `name`, failing clearly if it is unset or empty."""
    value = repository_ctx.os.environ.get(name)
    if not value:
        _fail("Environment variable %s is not set or empty; it is required to configure TensorFlow." % name)
    return value

def _tf_pip_impl(repository_ctx):
    """Implementation of tf_configure: generates a BUILD file for TensorFlow.

    The rule reads the header directory and the shared-library location from
    the environment. It builds one genrule that copies the headers under
    "include", renaming "tensorflow_core" path components to "tensorflow". It
    builds a second genrule that copies the shared library. Both genrules are
    substituted into the BUILD template.

    Args:
      repository_ctx: the repository_ctx object.
    """
    tf_header_dir = _tf_required_env(repository_ctx, _TF_HEADER_DIR)
    tf_header_rule = _symlink_genrule_for_dir(
        repository_ctx,
        tf_header_dir,
        "include",
        "tf_header_include",
        tf_pip_dir_rename_pair = ["tensorflow_core", "tensorflow"],
    )

    tf_shared_library_dir = _tf_required_env(repository_ctx, _TF_SHARED_LIBRARY_DIR)
    tf_shared_library_name = _tf_required_env(repository_ctx, _TF_SHARED_LIBRARY_NAME)
    tf_shared_library_path = "%s/%s" % (tf_shared_library_dir, tf_shared_library_name)

    # On Windows the library is exposed under the import-library name.
    if _is_windows(repository_ctx):
        library_dest = "_pywrap_tensorflow_internal.lib"
    else:
        library_dest = "libtensorflow_framework.so"
    tf_shared_library_rule = _symlink_genrule_for_dir(
        repository_ctx,
        None,
        "",
        "libtensorflow_framework.so",
        [tf_shared_library_path],
        [library_dest],
    )

    _tpl(repository_ctx, "BUILD", {
        "%{TF_HEADER_GENRULE}": tf_header_rule,
        "%{TF_SHARED_LIBRARY_GENRULE}": tf_shared_library_rule,
    })
|
|
|
|
# Repository rule exposing a locally installed TensorFlow (headers and shared
# library) through a generated BUILD file. The environment variables it reads
# are declared in `environ` so Bazel tracks them as inputs of the rule.
tf_configure = repository_rule(
    environ = [
        _TF_HEADER_DIR,
        _TF_SHARED_LIBRARY_DIR,
        _TF_SHARED_LIBRARY_NAME,
    ],
    implementation = _tf_pip_impl,
)
|