3.3 - Using tf.function

!wget -nc --no-cache -O init.py -q https://raw.githubusercontent.com/rramosp/2021.deeplearning/main/content/init.py
import init; init.init(force_download=False); 
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline
%load_ext tensorboard

from sklearn.datasets import *
from local.lib import mlutils
tf.__version__
'2.4.1'

tf.function automatically converts Python code into a TensorFlow computational graph that operates on Tensors

def f(x):
    """Return x**2 + 3*x, evaluated eagerly with plain Python arithmetic."""
    squared = x**2
    tripled = x*3
    return squared + tripled
f(2)
10
@tf.function
def f(x):
    """Graph version of x**2 + 3*x; tf.function traces it on the first call."""
    squared = x**2
    tripled = x*3
    return squared + tripled
f(2)
<tf.Tensor: shape=(), dtype=int32, numpy=10>

and it also works with a tf.Variable (a TensorFlow object), not only with Python numbers

x = tf.Variable(3.)
f(x)
<tf.Tensor: shape=(), dtype=float32, numpy=18.0>

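a tf.function is traced (converted to a computational graph) the first time it is called. The resulting graph is cached and reused ONLY by later calls whose arguments match that trace: the same Python values, the same tensor dtypes/shapes, and, as shown below, the same tf.Variable objects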

@tf.function
def f47(t):
    """Return t**2 + 47*t while reporting tracing vs. execution.

    The Python ``print`` only runs while tf.function traces the function.
    ``tf.print`` becomes an op in the graph and therefore runs on every call.
    """
    print('Tracing!')        # Python side effect: only during tracing
    tf.print('Executing')    # graph op: runs on every execution
    squared = t**2
    return squared + t*47
f47(2)
Tracing!
Executing
<tf.Tensor: shape=(), dtype=int32, numpy=98>
f47(2)
Executing
<tf.Tensor: shape=(), dtype=int32, numpy=98>

observe that if the type changes, the function is traced again since a different computational graph must be created

x = tf.Variable(2, dtype=tf.float32)
f47(x)
Tracing!
Executing
<tf.Tensor: shape=(), dtype=float32, numpy=98.0>
x.assign(3.4)
f47(x)
Executing
<tf.Tensor: shape=(), dtype=float32, numpy=171.36>

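in this TensorFlow version, tracing happens for EACH tf.Variable object: a new variable triggers a new trace even when it has the same dtype and shape as a variable seen before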

y = tf.Variable(2, dtype=tf.float32)
f47(y)
Tracing!
Executing
<tf.Tensor: shape=(), dtype=float32, numpy=98.0>
f47(y)
Executing
<tf.Tensor: shape=(), dtype=float32, numpy=98.0>
f47(x)
Executing
<tf.Tensor: shape=(), dtype=float32, numpy=171.36>
x = tf.Variable(3, dtype=tf.int32)
f47(x)
Tracing!
Executing
<tf.Tensor: shape=(), dtype=int32, numpy=150>
x.assign(9)
f47(x)
Executing
<tf.Tensor: shape=(), dtype=int32, numpy=504>
print (f47.pretty_printed_concrete_signatures())
f47(t)
  Args:
    t: VariableSpec(shape=(), dtype=tf.float32, name='t')
  Returns:
    float32 Tensor, shape=()

f47(t=2)
  Returns:
    int32 Tensor, shape=()

f47(t)
  Args:
    t: VariableSpec(shape=(), dtype=tf.float32, name='t')
  Returns:
    float32 Tensor, shape=()

f47(t)
  Args:
    t: VariableSpec(shape=(), dtype=tf.int32, name='t')
  Returns:
    int32 Tensor, shape=()

observe the actual code generated by tf.autograph

print(tf.autograph.to_code(f47.python_function))
def tf__f47(t):
    with ag__.FunctionScope('f47', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()
        ag__.ld(print)('Tracing!')
        ag__.converted_call(ag__.ld(tf).print, ('Executing',), None, fscope)
        try:
            do_return = True
            retval_ = ((ag__.ld(t) ** 2) + (ag__.ld(t) * 47))
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

performance of tf.function

def f1(x):
    """Mean of x**2 + 3*x, computed eagerly with NumPy."""
    elementwise = x**2 + x*3
    return np.mean(elementwise)

def f11(x):
    """Element-wise x**2 + 3*x with no reduction (array inputs keep their shape)."""
    quadratic, linear = x**2, x*3
    return quadratic + linear

@tf.function
def f2(x):
    """Attempt to use NumPy inside a tf.function (fails by design).

    NOTE(review): this is the "f2(X) --> error, why?" exercise, so calling f2
    is expected to raise. While tracing, ``x`` is presumably a symbolic graph
    tensor with no concrete values. ``np.mean`` would need to convert it to a
    NumPy array, which TensorFlow cannot do during tracing. Inside a
    tf.function use ``tf.reduce_mean`` instead (as f4 does).
    """
    values = x**2+x*3
    return np.mean(values)

def f3(x):
    """Mean of x**2 + 3*x, computed eagerly with TensorFlow ops.

    Uses the same expression as f1/f2/f11 (``x*3``, not ``x**3``), so the
    NumPy vs. eager vs. tf.function comparisons measure identical work.
    """
    return tf.reduce_mean(x**2 + x*3)

@tf.function
def f4(x):
    """Graph (tf.function) version of mean(x**2 + 3*x).

    Uses the same expression as f1/f2/f11 (``x*3``, not ``x**3``), so the
    timing comparison against the eager versions is fair.
    """
    return tf.reduce_mean(x**2 + x*3)

@tf.function
def f5(x):
    """Wrap the undecorated eager function f3 in tf.function.

    f3 is called while f5 is being traced, so its TF ops end up in f5's graph.
    """
    result = f3(x)
    return result
X = np.random.random(size=(1000,20)).astype(np.float32)
tX = tf.Variable(X, dtype=tf.float32)
# f2(X) --> error, why?
f1(X), f3(X), f4(X), f5(X)
(1.8393676,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5855545>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5855545>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5855545>)
f1(tX), f3(tX), f4(tX), f5(tX)
(1.8252264,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.57828975>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5782897>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5782897>)
# but
f11(tX)
<tf.Tensor: shape=(1000, 20), dtype=float32, numpy=
array([[2.3446867 , 3.1720169 , 2.4340954 , ..., 0.94587123, 3.915599  ,
        1.3334954 ],
       [3.6761417 , 0.35400784, 0.285192  , ..., 3.8192358 , 3.4868982 ,
        0.13817935],
       [3.5109763 , 2.8115458 , 1.7549498 , ..., 2.0345583 , 0.21928157,
        3.7829554 ],
       ...,
       [0.37408376, 2.8192651 , 2.4856312 , ..., 1.3615996 , 2.8233917 ,
        3.7900527 ],
       [0.64968884, 2.9891465 , 2.6093905 , ..., 0.15869308, 2.1982477 ,
        0.26403534],
       [3.0967698 , 2.259425  , 0.6859213 , ..., 0.01919816, 2.8295956 ,
        2.8226109 ]], dtype=float32)>
def f1(x):
    """Mean of x**2 + 3*x, computed eagerly with NumPy (benchmark baseline)."""
    elementwise = x**2 + x*3
    return np.mean(elementwise)

%timeit f1(X)
27 µs ± 254 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit f1(tX)
326 µs ± 8.95 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
def f3(x):
    """Mean of x**2 + 3*x with eager TensorFlow ops (benchmark).

    Uses ``x*3`` (not ``x**3``) to match f1, so the timings compare the
    same computation.
    """
    return tf.reduce_mean(x**2 + x*3)

%timeit f3(tX)
512 µs ± 12 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
@tf.function
def f4(x):
    """Graph (tf.function) version of mean(x**2 + 3*x) (benchmark).

    Uses ``x*3`` (not ``x**3``) to match f1/f3, so the timings compare the
    same computation.
    """
    return tf.reduce_mean(x**2 + x*3)

%timeit f4(tX)
128 µs ± 2.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit f4.python_function(tX)
498 µs ± 4.39 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
@tf.function
def f5(x):
    """tf.function wrapper around the eager f3 (benchmark).

    f3's ops are traced into f5's graph when f5 is first called.
    """
    result = f3(x)
    return result

%timeit f5(tX)
128 µs ± 1.68 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Underlying concrete functions are actual TF graphs with no polymorphism, tied to specific input types

tf.function maps python polymorphism to a set of different underlying concrete functions

@tf.function
def f(x):
    """Return x + x; tf.function builds one concrete function per input signature."""
    doubled = x + x
    return doubled
f(10), f(10.), f("a")
(<tf.Tensor: shape=(), dtype=int32, numpy=20>,
 <tf.Tensor: shape=(), dtype=int32, numpy=20>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'aa'>)

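note that f(10.) returned an int32 tensor. Python arguments are part of the trace cache key by value, and 10 == 10. (with equal hashes), so the call presumably reused the graph traced for f(10). Observe also the different object addresses (shown in the reprs, not hash codes) of each concrete function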

fs = f.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
fs, fs(tf.constant("aa"))
(<tensorflow.python.eager.function.ConcreteFunction at 0x7f681e3562e0>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'aaaa'>)
fi = f.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.int32))
fi, fi(tf.constant(1))
(<tensorflow.python.eager.function.ConcreteFunction at 0x7f68143064c0>,
 <tf.Tensor: shape=(), dtype=int32, numpy=2>)
ff = f.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.float32))
ff, ff(tf.constant(1.))
(<tensorflow.python.eager.function.ConcreteFunction at 0x7f6814388c10>,
 <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)

tf.function with keras layers

import numpy as np
np.random.seed(0)
data = np.random.randn(3, 2)
data
array([[ 1.76405235,  0.40015721],
       [ 0.97873798,  2.2408932 ],
       [ 1.86755799, -0.97727788]])
# Layers shared by the models below. input_shape is documented as a tuple;
# (2) is just the int 2, so write (2,) for the 2 features of each row of data.
inputer = tf.keras.layers.InputLayer(input_shape=(2,))
denser1 = tf.keras.layers.Dense(4, activation='relu')
denser2 = tf.keras.layers.Dense(1, activation='sigmoid')

observe that, in eager mode, each layer's operations are executed immediately as the Python code runs, so the prints show concrete tensor values on every call

def model_1(data):
    """Eager forward pass through the shared layers, printing intermediates.

    This runs without tf.function, so the prints execute on every call and
    show concrete tensor values.
    """
    hidden = denser1(inputer(data))
    print('After the first layer:', hidden)
    out = denser2(hidden)
    print('After the second layer:', out)
    return out

print('Model output:\n', model_1(data))
print("--")
print('Model output:\n', model_1(data+1))
After the first layer: tf.Tensor(
[[0.        0.        0.        0.       ]
 [1.1272627 0.        1.5222576 0.       ]
 [0.        0.        0.        1.0462083]], shape=(3, 4), dtype=float32)
After the second layer: tf.Tensor(
[[0.5       ]
 [0.8619154 ]
 [0.61063063]], shape=(3, 1), dtype=float32)
Model output:
 tf.Tensor(
[[0.5       ]
 [0.8619154 ]
 [0.61063063]], shape=(3, 1), dtype=float32)
--
After the first layer: tf.Tensor(
[[0.27164382 0.         0.         0.        ]
 [1.5170838  0.         1.8968623  0.        ]
 [0.         0.         0.         0.18906634]], shape=(3, 4), dtype=float32)
After the second layer: tf.Tensor(
[[0.5418302 ]
 [0.9130416 ]
 [0.52031773]], shape=(3, 1), dtype=float32)
Model output:
 tf.Tensor(
[[0.5418302 ]
 [0.9130416 ]
 [0.52031773]], shape=(3, 1), dtype=float32)

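however, with tf.function the function is FIRST traced into a computational graph. The Python print calls run only during this tracing, so they show symbolic tensors. That graph is THEN reused in subsequent calls, so the prints do not appear again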

@tf.function
def model_2(data):
    """Same forward pass as the eager model_1, compiled with tf.function.

    The Python prints run only while the function is traced, so they show
    symbolic tensors. Later calls with a matching input signature reuse the
    cached graph and print nothing.
    """
    hidden = denser1(inputer(data))
    print('After the first layer:', hidden)
    out = denser2(hidden)
    print('After the second layer:', out)
    return out


print('Model\'s output:', model_2(data))
print('--')
print('Model\'s output:', model_2(data+1))
After the first layer: Tensor("dense/Relu:0", shape=(3, 4), dtype=float32)
After the second layer: Tensor("dense_1/Sigmoid:0", shape=(3, 1), dtype=float32)
Model's output: tf.Tensor(
[[0.5       ]
 [0.8619154 ]
 [0.61063063]], shape=(3, 1), dtype=float32)
--
Model's output: tf.Tensor(
[[0.5418302 ]
 [0.9130416 ]
 [0.52031773]], shape=(3, 1), dtype=float32)

tf.function usually requires less compute time. In eager mode every operation is dispatched from Python each time the function is called, whereas with tf.function the graph is traced once and then reused

def model_1(data):
    """Eager forward pass: inputer -> denser1 -> denser2 (no prints)."""
    return denser2(denser1(inputer(data)))

@tf.function
def model_2(data):
    """tf.function version of the forward pass inputer -> denser1 -> denser2."""
    return denser2(denser1(inputer(data)))
%timeit model_1(data)
330 µs ± 4.57 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit model_2(data)
101 µs ± 1.87 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

including graphs (tf.functions) inside upstream functions, such as a gradient computation.

observe how we compute the gradient of a computational graph:

  • with model_1 the operations are executed eagerly, one by one, each time the function is called

  • with model_2 the graph is traced only in the first call and reused afterwards

def g1(data):
    """Gradient of the eager model_1 output w.r.t. denser1's variables.

    Both the forward pass recorded on the tape and the backward pass
    run eagerly.
    """
    with tf.GradientTape() as tape:
        y = model_1(data)
    targets = denser1.variables
    return tape.gradient(y, targets)

def g2(data):
    """Gradient of the tf.function model_2 output w.r.t. denser1's variables.

    The forward pass is a cached graph, but this function itself (and
    therefore the tape's backward pass) is not wrapped in tf.function.
    """
    with tf.GradientTape() as tape:
        y = model_2(data)
    targets = denser1.variables
    return tape.gradient(y, targets)

g2(data), g1(data)
([<tf.Tensor: shape=(2, 4), dtype=float32, numpy=
  array([[ 0.07191873,  0.        ,  0.08687735,  0.19097383],
         [ 0.16466327,  0.        ,  0.19891211, -0.09993505]],
        dtype=float32)>,
  <tf.Tensor: shape=(4,), dtype=float32, numpy=array([0.07348109, 0.        , 0.08876466, 0.10225859], dtype=float32)>],
 [<tf.Tensor: shape=(2, 4), dtype=float32, numpy=
  array([[ 0.07191873,  0.        ,  0.08687735,  0.19097383],
         [ 0.16466327,  0.        ,  0.19891211, -0.09993505]],
        dtype=float32)>,
  <tf.Tensor: shape=(4,), dtype=float32, numpy=array([0.07348109, 0.        , 0.08876466, 0.10225859], dtype=float32)>])
%timeit g1(data)
686 µs ± 8.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit g2(data)
406 µs ± 11.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

however, even in g2 the gradient computation (the tape's backward pass) still runs eagerly.

if we wrap either function with tf.function, the forward pass and the gradient computation both become part of a single cached computational graph.

fg1 = tf.function(g1)
%timeit fg1(data)
117 µs ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
fg2 = tf.function(g2)
%timeit fg2(data)
116 µs ± 2.17 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)