3.3 - Using tf.function

!wget -nc --no-cache -O init.py -q https://raw.githubusercontent.com/rramosp/2021.deeplearning/main/content/init.py
import init; init.init(force_download=False);
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline
%load_ext tensorboard

from sklearn.datasets import *
from local.lib import mlutils
tf.__version__
'2.19.0'

tf.function automatically converts Pythonic code into a computational graph operating on Tensors

def f(x):
    """Return x squared plus three times x, using plain Python arithmetic."""
    square = x ** 2
    linear = x * 3
    return square + linear
f(2)
10
@tf.function
def f(x):
    """Same computation as the plain-Python version (x**2 + 3*x), run as a TensorFlow graph."""
    square = x ** 2
    linear = x * 3
    return square + linear
f(2)
<tf.Tensor: shape=(), dtype=int32, numpy=10>

and it also works with a tf.Variable

x = tf.Variable(3.)
f(x)
<tf.Tensor: shape=(), dtype=float32, numpy=18.0>

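a tf.function is traced (converted into a computational graph) the first time it is called with a given input signature; the graph is then cached and REUSED FOR LATER CALLS WITH A MATCHING SIGNATURE. Observe that the Python `print` runs only while tracing, whereas `tf.print` is part of the graph and runs on every call.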

# Demonstrates tracing vs. execution: the Python print marks each (re)trace,
# while tf.print marks every actual execution of the graph.
@tf.function
def f47(t):
    # Python side effect: only runs while tf.function traces the function,
    # so it appears once per new input signature in the cell output.
    print('Tracing!')
    # TensorFlow op: becomes part of the traced graph and runs on every call.
    tf.print('Executing')
    return t**2 + t*47
f47(2)
Tracing!
Executing
<tf.Tensor: shape=(), dtype=int32, numpy=98>
f47(2)
Executing
<tf.Tensor: shape=(), dtype=int32, numpy=98>

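observe that if the input type changes (here, from a Python int to a float32 tf.Variable), the function is traced again, since a different computational graph must be created. Assigning a new value to the same variable does not trigger a retrace.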

x = tf.Variable(2, dtype=tf.float32)
f47(x)
Tracing!
Executing
<tf.Tensor: shape=(), dtype=float32, numpy=98.0>
x.assign(3.4)
f47(x)
Executing
<tf.Tensor: shape=(), dtype=float32, numpy=171.36000061035156>

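tracing does NOT happen for each new variable. The new float32 variable `y` below reuses the graph already traced for `x`, because the cache is keyed on the input signature (dtype and shape), not on the variable object. Only a variable with a different dtype (int32 below) triggers a new trace, as the concrete signatures printed afterwards confirm.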

y = tf.Variable(2, dtype=tf.float32)
f47(y)
Executing
<tf.Tensor: shape=(), dtype=float32, numpy=98.0>
f47(y)
Executing
<tf.Tensor: shape=(), dtype=float32, numpy=98.0>
f47(x)
Executing
<tf.Tensor: shape=(), dtype=float32, numpy=171.36000061035156>
x = tf.Variable(3, dtype=tf.int32)
f47(x)
Tracing!
Executing
<tf.Tensor: shape=(), dtype=int32, numpy=150>
x.assign(9)
f47(x)
Executing
<tf.Tensor: shape=(), dtype=int32, numpy=504>
print (f47.pretty_printed_concrete_signatures())
Input Parameters:
  t (POSITIONAL_OR_KEYWORD): Literal[2]
Output Type:
  TensorSpec(shape=(), dtype=tf.int32, name=None)
Captures:
  None

Input Parameters:
  t (POSITIONAL_OR_KEYWORD): VariableSpec(shape=(), dtype=tf.float32, trainable=True, alias_id=0)
Output Type:
  TensorSpec(shape=(), dtype=tf.float32, name=None)
Captures:
  None

Input Parameters:
  t (POSITIONAL_OR_KEYWORD): VariableSpec(shape=(), dtype=tf.int32, trainable=True, alias_id=0)
Output Type:
  TensorSpec(shape=(), dtype=tf.int32, name=None)
Captures:
  None

observe the code actually generated by tf.autograph

print(tf.autograph.to_code(f47.python_function))
def tf__f47(t):
    with ag__.FunctionScope('f47', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()
        ag__.ld(print)('Tracing!')
        ag__.converted_call(ag__.ld(tf).print, ('Executing',), None, fscope)
        try:
            do_return = True
            retval_ = ag__.ld(t) ** 2 + ag__.ld(t) * 47
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

performance of tf.function

def f1(x):
    """Return the mean of x**2 + 3*x, reduced with NumPy."""
    elementwise = x**2 + x*3
    return np.mean(elementwise)

def f11(x):
    """Return x**2 + 3*x elementwise, without any reduction."""
    squared = x**2
    return squared + x*3

@tf.function
def f2(x):
    """NumPy mean wrapped in tf.function; kept to illustrate a failure mode.

    Per the `f2(X) --> error` note at the call site, calling this fails:
    during tracing x is a symbolic tensor, which np.mean presumably cannot
    convert to a NumPy array.
    """
    values = x**2 + x*3
    return np.mean(values)

def f3(x):
    """Return the mean of x**2 + 3*x, computed eagerly with TensorFlow ops.

    Uses ``x*3`` (not ``x**3``) so it computes the same quantity as the NumPy
    function f1. Otherwise the NumPy-vs-TensorFlow timings compare different
    computations.
    """
    return tf.reduce_mean(x**2 + x*3)

@tf.function
def f4(x):
    """Graph-compiled mean of x**2 + 3*x.

    Uses ``x*3`` (not ``x**3``) so it computes the same quantity as the NumPy
    function f1, keeping the timing comparison meaningful.
    """
    return tf.reduce_mean(x**2 + x*3)

@tf.function
def f5(x):
    """tf.function wrapper around the undecorated function f3; f3's ops are traced into this graph."""
    result = f3(x)
    return result
X = np.random.random(size=(1000,20)).astype(np.float32)
tX = tf.Variable(X, dtype=tf.float32)
# f2(X) --> error, why?
f1(X), f3(X), f4(X), f5(X)
(np.float32(1.8400924),
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5865188837051392>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5865188837051392>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5865188837051392>)
f1(tX), f3(tX), f4(tX), f5(tX)
(np.float32(1.8400924),
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5865188837051392>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5865188837051392>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5865188837051392>)
# but
f11(tX)
<tf.Tensor: shape=(1000, 20), dtype=float32, numpy=
array([[2.5066283 , 3.566715  , 0.21405369, ..., 1.1387894 , 0.21747205,
        0.5582415 ],
       [0.8427817 , 2.307036  , 2.4168954 , ..., 1.2836732 , 1.1275133 ,
        2.156976  ],
       [3.8027775 , 0.8254186 , 2.5732207 , ..., 1.3825173 , 3.4700062 ,
        2.5409565 ],
       ...,
       [0.8170469 , 3.4651794 , 1.0635482 , ..., 0.04916687, 0.48021984,
        2.854732  ],
       [0.7078032 , 1.1500525 , 3.8444602 , ..., 3.0917153 , 0.8404596 ,
        3.0876915 ],
       [2.2794511 , 2.3670022 , 3.5228426 , ..., 0.56031686, 3.0963511 ,
        0.1286898 ]], dtype=float32)>
def f1(x):
    """NumPy baseline for the timing comparison: mean of x**2 + 3*x."""
    terms = x**2 + x*3
    return np.mean(terms)

%timeit f1(X)
33 µs ± 995 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit f1(tX)
1.19 ms ± 152 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
def f3(x):
    """Eager TensorFlow version of the timed computation: mean of x**2 + 3*x.

    Uses ``x*3`` (not ``x**3``) to match the NumPy baseline f1, so the timings
    compare the same computation.
    """
    return tf.reduce_mean(x**2 + x*3)

%timeit f3(tX)
1.87 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
@tf.function
def f4(x):
    """Graph-compiled version of the timed computation: mean of x**2 + 3*x.

    Uses ``x*3`` (not ``x**3``) to match the NumPy baseline f1, so the timings
    compare the same computation.
    """
    return tf.reduce_mean(x**2 + x*3)

%timeit f4(tX)
461 µs ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit f4.python_function(tX)
2.14 ms ± 85.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
@tf.function
def f5(x):
    """Graph-compile the eager function f3 by calling it from inside a tf.function."""
    result = f3(x)
    return result

%timeit f5(tX)
451 µs ± 9.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Underlying concrete functions are actual TF graphs with no polymorphism, tied to specific input types

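tf.function maps Python polymorphism onto a set of different underlying concrete functions, one per input signature. Note a subtlety in the output below: `f(10.)` returns an int32 tensor. This is apparently because Python scalars are traced as literal values and `10. == 10` in Python, so the graph already traced for `f(10)` is reused.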

@tf.function
def f(x):
    """Return x added to itself; works for numeric and string tensors alike."""
    doubled = x + x
    return doubled
f(10), f(10.), f("a")
(<tf.Tensor: shape=(), dtype=int32, numpy=20>,
 <tf.Tensor: shape=(), dtype=int32, numpy=20>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'aa'>)

observe that each concrete function is a distinct object (note the different `at 0x...` addresses), each tied to one input dtype

fs = f.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
fs, fs(tf.constant("aa"))
(<ConcreteFunction (x: TensorSpec(shape=<unknown>, dtype=tf.string, name=None)) -> TensorSpec(shape=<unknown>, dtype=tf.string, name=None) at 0x7F406A8BA660>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'aaaa'>)
fi = f.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.int32))
fi, fi(tf.constant(1))
(<ConcreteFunction (x: TensorSpec(shape=<unknown>, dtype=tf.int32, name=None)) -> TensorSpec(shape=<unknown>, dtype=tf.int32, name=None) at 0x7F406A8BB7D0>,
 <tf.Tensor: shape=(), dtype=int32, numpy=2>)
ff = f.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.float32))
ff, ff(tf.constant(1.))
(<ConcreteFunction (x: TensorSpec(shape=<unknown>, dtype=tf.float32, name=None)) -> TensorSpec(shape=<unknown>, dtype=tf.float32, name=None) at 0x7F406B9E01D0>,
 <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)

tf.function with keras layers

import numpy as np
np.random.seed(0)
data = np.random.randn(3, 2)
data
array([[ 1.76405235,  0.40015721],
       [ 0.97873798,  2.2408932 ],
       [ 1.86755799, -0.97727788]])
denser1 = tf.keras.layers.Dense(4, activation='relu')
denser2 = tf.keras.layers.Dense(1, activation='sigmoid')

observe that, in eager mode, the layers' operations are executed immediately as the Python code runs, so the `print` calls show concrete tensor values on every call

def model_1(data):
    """Eagerly run data through denser1 and then denser2, printing each layer's output."""
    hidden = denser1(data)
    print('After the first layer:', hidden)
    prediction = denser2(hidden)
    print('After the second layer:', prediction)
    return prediction

print('Model output:\n', model_1(data))
print("--")
print('Model output:\n', model_1(data+1))
After the first layer: tf.Tensor(
[[0.         0.         0.36905685 0.        ]
 [1.3161621  1.5016826  0.8457896  0.20686978]
 [0.         0.         0.         0.        ]], shape=(3, 4), dtype=float32)
After the second layer: tf.Tensor(
[[0.4046358]
 [0.5677154]
 [0.5      ]], shape=(3, 1), dtype=float32)
Model output:
 tf.Tensor(
[[0.4046358]
 [0.5677154]
 [0.5      ]], shape=(3, 1), dtype=float32)
--
After the first layer: tf.Tensor(
[[0.         0.         0.82375824 0.        ]
 [1.3717937  1.9063075  1.3004911  0.        ]
 [0.         0.         0.40059814 0.        ]], shape=(3, 4), dtype=float32)
After the second layer: tf.Tensor(
[[0.29692343]
 [0.5237631 ]
 [0.3967103 ]], shape=(3, 1), dtype=float32)
Model output:
 tf.Tensor(
[[0.29692343]
 [0.5237631 ]
 [0.3967103 ]], shape=(3, 1), dtype=float32)

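however, with tf.function, the function is FIRST traced into a computational graph, which is THEN reused in subsequent calls. The Python `print` calls only run during tracing, so they appear once and show symbolic tensors rather than values.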

@tf.function
def model_2(data):
    """Graph-compiled denser1 -> denser2 forward pass; the prints fire only while tracing."""
    hidden = denser1(data)
    print('After the first layer:', hidden)
    prediction = denser2(hidden)
    print('After the second layer:', prediction)
    return prediction


print('Model\'s output:', model_2(data))
print('--')
print('Model\'s output:', model_2(data+1))
After the first layer: Tensor("dense_3_1/Relu:0", shape=(3, 4), dtype=float32)
After the second layer: Tensor("dense_4_1/Sigmoid:0", shape=(3, 1), dtype=float32)
Model's output: tf.Tensor(
[[0.4046358]
 [0.5677154]
 [0.5      ]], shape=(3, 1), dtype=float32)
--
Model's output: tf.Tensor(
[[0.29692343]
 [0.5237631 ]
 [0.3967103 ]], shape=(3, 1), dtype=float32)

tf.function usually requires less compute time. In eager mode, every operation is dispatched from Python one at a time on each call, whereas a tf.function executes its already-traced graph as a whole.

def model_1(data):
    """Eager forward pass: apply denser1, then denser2."""
    return denser2(denser1(data))

@tf.function
def model_2(data):
    """Graph-compiled forward pass: apply denser1, then denser2."""
    return denser2(denser1(data))
%timeit model_1(data)
2.87 ms ± 652 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit model_2(data)
618 µs ± 10.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

including graphs in upstream functions

observe how we compute gradients through each model:

  • with model_1, the forward-pass operations run eagerly every time the function is called

  • with model_2, the forward-pass graph is traced only on the first call and reused afterwards

def g1(data):
    """Gradient of model_1's output with respect to denser1's variables."""
    with tf.GradientTape() as tape:
        output = model_1(data)

    return tape.gradient(output, denser1.variables)

def g2(data):
    """Gradient of model_2's output with respect to denser1's variables."""
    with tf.GradientTape() as tape:
        output = model_2(data)

    return tape.gradient(output, denser1.variables)

g2(data), g1(data)
([<tf.Tensor: shape=(2, 4), dtype=float32, numpy=
  array([[ 0.1410127 ,  0.08198048, -0.69603944, -0.1481947 ],
         [ 0.322859  ,  0.18770038, -0.67634726, -0.33930275]],
        dtype=float32)>,
  <tf.Tensor: shape=(4,), dtype=float32, numpy=array([ 0.14407603,  0.08376142, -0.50889206, -0.15141407], dtype=float32)>],
 [<tf.Tensor: shape=(2, 4), dtype=float32, numpy=
  array([[ 0.1410127 ,  0.08198048, -0.69603944, -0.1481947 ],
         [ 0.322859  ,  0.18770038, -0.67634726, -0.33930275]],
        dtype=float32)>,
  <tf.Tensor: shape=(4,), dtype=float32, numpy=array([ 0.14407603,  0.08376142, -0.50889206, -0.15141407], dtype=float32)>])
%timeit g1(data)
5.2 ms ± 1.07 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit g2(data)
1.8 ms ± 301 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

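however, even in g2 the gradient computation itself (the GradientTape and `t.gradient`) still runs eagerly.

if we wrap either function with tf.function, the forward pass and the gradient computation are both traced into a single cached computational graph.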

fg1 = tf.function(g1)
%timeit fg1(data)
650 µs ± 10.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
fg2 = tf.function(g2)
%timeit fg2(data)
711 µs ± 98.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)