.. rst-class:: sphx-glr-example-title

.. _sphx_glr_how_to_compile_models_from_tensorflow.py:


Compile Tensorflow Models
=========================
This article is an introductory tutorial on deploying TensorFlow models with TVM.

To begin, the TensorFlow Python module must be installed.
Please refer to https://www.tensorflow.org/install


.. code-block:: default


    # tvm, relay
    import tvm
    from tvm import te
    from tvm import relay

    # os and numpy
    import numpy as np
    import os.path

    # Tensorflow imports
    import tensorflow as tf


    # Ask tensorflow to limit its GPU memory to what's actually needed
    # instead of gobbling everything that's available.
    # https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
    # This way this tutorial is a little more friendly to sphinx-gallery.
    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            print("tensorflow will use experimental.set_memory_growth(True)")
        except RuntimeError as e:
            print("experimental.set_memory_growth option is not available: {}".format(e))


    # Older TensorFlow releases have no tf.compat.v1; accessing a missing
    # attribute raises AttributeError, not ImportError.
    try:
        tf_compat_v1 = tf.compat.v1
    except (AttributeError, ImportError):
        tf_compat_v1 = tf

    # Tensorflow utility functions
    import tvm.relay.testing.tf as tf_testing

    # Base location for model related files.
    repo_base = "https://github.com/dmlc/web-data/raw/main/tensorflow/models/InceptionV1/"

    # Test image
    img_name = "elephant-299.jpg"
    image_url = os.path.join(repo_base, img_name)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    tensorflow will use experimental.set_memory_growth(True)


Tutorials
---------
Please refer to docs/frontend/tensorflow.md for more details about various models from TensorFlow.


.. code-block:: default


    model_name = "classify_image_graph_def-with_shapes.pb"
    model_url = os.path.join(repo_base, model_name)

    # Image label map
    map_proto = "imagenet_2012_challenge_label_map_proto.pbtxt"
    map_proto_url = os.path.join(repo_base, map_proto)

    # Human readable text for labels
    label_map = "imagenet_synset_to_human_label_map.txt"
    label_map_url = os.path.join(repo_base, label_map)

    # Target settings
    # Use these commented settings to build for cuda.
    # target = tvm.target.Target("cuda", host="llvm")
    # layout = "NCHW"
    # dev = tvm.cuda(0)
    target = tvm.target.Target("llvm", host="llvm")
    layout = None
    dev = tvm.cpu(0)


Download required files
-----------------------
Download the files listed above.


.. code-block:: default


    from tvm.contrib.download import download_testdata

    img_path = download_testdata(image_url, img_name, module="data")
    model_path = download_testdata(model_url, model_name, module=["tf", "InceptionV1"])
    map_proto_path = download_testdata(map_proto_url, map_proto, module="data")
    label_path = download_testdata(label_map_url, label_map, module="data")


Import model
------------
Create the TensorFlow graph definition from the protobuf file.


.. code-block:: default


    with tf_compat_v1.gfile.GFile(model_path, "rb") as f:
        graph_def = tf_compat_v1.GraphDef()
        graph_def.ParseFromString(f.read())
        graph = tf.import_graph_def(graph_def, name="")
        # Call the utility to import the graph definition into default graph.
        graph_def = tf_testing.ProcessGraphDefParam(graph_def)
        # Add shapes to the graph.
        with tf_compat_v1.Session() as sess:
            graph_def = tf_testing.AddShapesToGraphDef(sess, "softmax")


Decode image
------------
.. note::

  The TensorFlow frontend does not support preprocessing ops such as JpegDecode.
  JpegDecode is bypassed (the converter simply returns its source node), so we
  decode the image ourselves and supply the decoded frame to TVM instead.


.. code-block:: default


    from PIL import Image

    image = Image.open(img_path).resize((299, 299))

    x = np.array(image)


Import the graph to Relay
-------------------------
Import the TensorFlow graph definition into the Relay frontend.

Results:
  mod: Relay module (IRModule) for the given TensorFlow protobuf.
  params: params converted from the TensorFlow params (tensor protobuf).


.. code-block:: default


    # Static shape of the decoded-image input fed to TVM.
    shape_dict = {"DecodeJpeg/contents": x.shape}
    mod, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict)

    print("Tensorflow protobuf imported to relay frontend.")


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    /workspace/python/tvm/relay/frontend/tensorflow.py:535: UserWarning: Ignore the passed shape. Shape in graphdef will be used for operator DecodeJpeg/contents.
      "will be used for operator %s." % node.name
    /workspace/python/tvm/relay/frontend/tensorflow_ops.py:1006: UserWarning: DecodeJpeg: It's a pass through, please handle preprocessing before input
      warnings.warn("DecodeJpeg: It's a pass through, please handle preprocessing before input")
    Tensorflow protobuf imported to relay frontend.


Relay Build
-----------
Compile the graph to the llvm target with the given input specification.

Results:
  graph: Final graph after compilation.
  params: Final params after compilation.
  lib: Target library which can be deployed on the target with the TVM runtime.

``relay.build`` returns all three bundled together in a single factory module, which is
what ``lib`` holds below.


.. code-block:: default


    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target, params=params)


Execute the portable graph on TVM
---------------------------------
Now we can try deploying the compiled model on the target.


.. code-block:: default


    from tvm.contrib import graph_executor

    dtype = "uint8"
    m = graph_executor.GraphModule(lib["default"](dev))
    # set inputs
    m.set_input("DecodeJpeg/contents", tvm.nd.array(x.astype(dtype)))
    # execute
    m.run()
    # get outputs
    tvm_output = m.get_output(0, tvm.nd.empty((1, 1008), "float32"))


Process the output
------------------
Convert the model output into human-readable text for InceptionV1.


.. code-block:: default


    predictions = tvm_output.numpy()
    predictions = np.squeeze(predictions)

    # Creates node ID --> English string lookup.
    node_lookup = tf_testing.NodeLookup(label_lookup_path=map_proto_path, uid_lookup_path=label_path)

    # Print top 5 predictions from TVM output.
    top_k = predictions.argsort()[-5:][::-1]
    for node_id in top_k:
        human_string = node_lookup.id_to_string(node_id)
        score = predictions[node_id]
        print("%s (score = %.5f)" % (human_string, score))


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    African elephant, Loxodonta africana (score = 0.58335)
    tusker (score = 0.33901)
    Indian elephant, Elephas maximus (score = 0.02391)
    banana (score = 0.00025)
    vault (score = 0.00021)


Inference on TensorFlow
-----------------------
Run the corresponding model on TensorFlow.


.. code-block:: default


    def create_graph():
        """Creates a graph from the saved GraphDef file."""
        # Creates graph from saved graph_def.pb.
        with tf_compat_v1.gfile.GFile(model_path, "rb") as f:
            graph_def = tf_compat_v1.GraphDef()
            graph_def.ParseFromString(f.read())
            graph = tf.import_graph_def(graph_def, name="")
            # Call the utility to import the graph definition into default graph.
            graph_def = tf_testing.ProcessGraphDefParam(graph_def)


    def run_inference_on_image(image):
        """Runs inference on an image.

        Parameters
        ----------
        image: String
            Image file name.

        Returns
        -------
            Nothing
        """
        if not tf_compat_v1.gfile.Exists(image):
            # tf.logging was removed in TF 2.x; the compat.v1 module still provides it.
            tf_compat_v1.logging.fatal("File does not exist %s", image)
        image_data = tf_compat_v1.gfile.GFile(image, "rb").read()

        # Creates graph from saved GraphDef.
        create_graph()

        with tf_compat_v1.Session() as sess:
            softmax_tensor = sess.graph.get_tensor_by_name("softmax:0")
            predictions = sess.run(softmax_tensor, {"DecodeJpeg/contents:0": image_data})

            predictions = np.squeeze(predictions)

            # Creates node ID --> English string lookup.
            node_lookup = tf_testing.NodeLookup(
                label_lookup_path=map_proto_path, uid_lookup_path=label_path
            )

            # Print top 5 predictions from tensorflow.
            top_k = predictions.argsort()[-5:][::-1]
            print("===== TENSORFLOW RESULTS =======")
            for node_id in top_k:
                human_string = node_lookup.id_to_string(node_id)
                score = predictions[node_id]
                print("%s (score = %.5f)" % (human_string, score))


    run_inference_on_image(img_path)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    ===== TENSORFLOW RESULTS =======
    African elephant, Loxodonta africana (score = 0.58394)
    tusker (score = 0.33909)
    Indian elephant, Elephas maximus (score = 0.03186)
    banana (score = 0.00022)
    desk (score = 0.00019)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minute 21.217 seconds)
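
The TVM and TensorFlow scores agree closely but are not identical. For example, "African elephant" scores 0.58335 versus 0.58394, and the fifth-ranked label differs between the two runs. The sketch below makes that comparison explicit. It assumes both softmax outputs are available as NumPy arrays. The helper ``compare_topk`` is hypothetical and not part of TVM or TensorFlow. The six-class vectors are made-up stand-ins for the real ``(1, 1008)`` outputs. Also note that ``run_inference_on_image`` as written does not return its predictions, so applying this to the real outputs would require a small change to that function.

.. code-block:: default

    import numpy as np


    def compare_topk(tvm_probs, tf_probs, k=5):
        """Hypothetical helper: compare two softmax outputs of the same model.

        Returns the top-k class indices of each output, whether the top-1
        predictions agree, and the largest absolute score difference.
        """
        tvm_probs = np.squeeze(np.asarray(tvm_probs, dtype="float32"))
        tf_probs = np.squeeze(np.asarray(tf_probs, dtype="float32"))
        if tvm_probs.shape != tf_probs.shape:
            raise ValueError("outputs must have the same shape")

        # Same idiom as the tutorial: ascending argsort, keep the last k, reverse.
        tvm_top = tvm_probs.argsort()[-k:][::-1]
        tf_top = tf_probs.argsort()[-k:][::-1]

        return {
            "tvm_top": tvm_top.tolist(),
            "tf_top": tf_top.tolist(),
            "top1_match": bool(tvm_top[0] == tf_top[0]),
            "max_abs_diff": float(np.max(np.abs(tvm_probs - tf_probs))),
        }


    # Toy 6-class score vectors (assumed values) standing in for the real outputs.
    tvm_out = np.array([[0.01, 0.58, 0.34, 0.02, 0.03, 0.02]])
    tf_out = np.array([[0.01, 0.59, 0.33, 0.03, 0.02, 0.02]])

    report = compare_topk(tvm_out, tf_out, k=3)
    print(report)

In this toy case, the top-1 class agrees but the third-ranked class differs. That mirrors the real runs above, where the first four labels match and the fifth does not. Checking top-1 agreement plus a tolerance on ``max_abs_diff`` is more robust than expecting bit-identical scores. The compiled kernels may accumulate floating-point results in a different order than TensorFlow does.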