3.3 - Using tf.function
#
# Download the course helper init.py (-nc: do not clobber an existing file, -q: quiet).
!wget -nc --no-cache -O init.py -q https://raw.githubusercontent.com/rramosp/2021.deeplearning/main/content/init.py
# NOTE(review): presumably init.init() sets up the course `local` package imported below — TODO confirm.
import init; init.init(force_download=False);
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline
%load_ext tensorboard
# NOTE(review): the wildcard sklearn import and mlutils are not used in the cells shown here.
from sklearn.datasets import *
from local.lib import mlutils
# Display the installed TensorFlow version.
tf.__version__
'2.19.0'
tf.function automatically converts Pythonic code into a computational graph that operates on Tensors
def f(x):
    """Plain Python function: evaluate x**2 + 3*x eagerly, with no graph involved."""
    square = x**2
    return square + x*3


f(2)
10
@tf.function
def f(x):
    """Same computation as the plain version, but traced by tf.function into a graph returning a tf.Tensor."""
    square = x**2
    return square + x*3


f(2)
<tf.Tensor: shape=(), dtype=int32, numpy=10>
and it also works with a tf.Variable as input
# A float tf.Variable holding 3.0 as the argument: 3.0**2 + 3*3.0 == 18.0
x = tf.Variable(initial_value=3.)
f(x)
<tf.Tensor: shape=(), dtype=float32, numpy=18.0>
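a tf.function is traced (converted into a computational graph) the first time it is called with a given input signature; the graph is then cached and reused on later calls with the same signature (the same Python values, or tensors/variables with the same dtype and shape)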
# f47 computes t**2 + 47*t and makes tracing visible: Python's print fires only
# while tf.function traces, tf.print fires on every execution of the graph.
@tf.function
def f47(t):
    # Python-level print: runs only during tracing, so "Tracing!" marks a new trace.
    print('Tracing!')
    # tf.print is recorded into the graph, so it runs on every call (traced or cached).
    tf.print('Executing')
    return t**2 + t*47
# First call with the Python int 2: traces, then executes.
f47(2)
Tracing!
Executing
<tf.Tensor: shape=(), dtype=int32, numpy=98>
# Same Python value 2 again: the cached graph is reused, so only "Executing" is printed.
f47(2)
Executing
<tf.Tensor: shape=(), dtype=int32, numpy=98>
observe that if the type changes, the function is traced again since a different computational graph must be created
# A float32 tf.Variable is a different input type from the Python int 2, so f47 is traced again.
x = tf.Variable(2, dtype=tf.float32)
f47(x)
Tracing!
Executing
<tf.Tensor: shape=(), dtype=float32, numpy=98.0>
# Changing the variable's value (same dtype and shape) does not retrace; the cached graph reads the new value.
x.assign(3.4)
f47(x)
Executing
<tf.Tensor: shape=(), dtype=float32, numpy=171.36000061035156>
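a new variable with the same dtype and shape does NOT trigger a new trace: y below reuses the graph already traced for x (no "Tracing!" is printed). Only a change of dtype (float32 → int32, further below) causes retracing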
# A different float32 scalar variable: the recorded output shows no "Tracing!", so the trace made for x is reused.
y = tf.Variable(2, dtype=tf.float32)
f47(y)
Executing
<tf.Tensor: shape=(), dtype=float32, numpy=98.0>
# Calling again with y: the cached graph is reused.
f47(y)
Executing
<tf.Tensor: shape=(), dtype=float32, numpy=98.0>
# Back to x (now holding 3.4): still no retracing.
f47(x)
Executing
<tf.Tensor: shape=(), dtype=float32, numpy=171.36000061035156>
# Rebind x to a new int32 variable: the dtype changes, so f47 is traced once more.
x = tf.Variable(3, dtype=tf.int32)
f47(x)
Tracing!
Executing
<tf.Tensor: shape=(), dtype=int32, numpy=150>
# assign() keeps dtype int32 and the scalar shape: no retracing.
x.assign(9)
f47(x)
Executing
<tf.Tensor: shape=(), dtype=int32, numpy=504>
# List the concrete functions traced so far, one per distinct input signature.
print (f47.pretty_printed_concrete_signatures())
Input Parameters:
t (POSITIONAL_OR_KEYWORD): Literal[2]
Output Type:
TensorSpec(shape=(), dtype=tf.int32, name=None)
Captures:
None
Input Parameters:
t (POSITIONAL_OR_KEYWORD): VariableSpec(shape=(), dtype=tf.float32, trainable=True, alias_id=0)
Output Type:
TensorSpec(shape=(), dtype=tf.float32, name=None)
Captures:
None
Input Parameters:
t (POSITIONAL_OR_KEYWORD): VariableSpec(shape=(), dtype=tf.int32, trainable=True, alias_id=0)
Output Type:
TensorSpec(shape=(), dtype=tf.int32, name=None)
Captures:
None
observe the code actually generated by tf.autograph
# Show the Python source that tf.autograph generates from f47's original (undecorated) Python function.
print(tf.autograph.to_code(f47.python_function))
def tf__f47(t):
with ag__.FunctionScope('f47', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
do_return = False
retval_ = ag__.UndefinedReturnValue()
ag__.ld(print)('Tracing!')
ag__.converted_call(ag__.ld(tf).print, ('Executing',), None, fscope)
try:
do_return = True
retval_ = ag__.ld(t) ** 2 + ag__.ld(t) * 47
except:
do_return = False
raise
return fscope.ret(retval_, do_return)
Performance of tf.function
def f1(x):
    """Mean of x**2 + 3*x computed with NumPy (eager, no graph)."""
    values = x**2 + x*3
    return np.mean(values)


def f11(x):
    """Element-wise x**2 + 3*x with no reduction, so the result keeps the input's shape."""
    squared = x**2
    return squared + x*3
@tf.function
def f2(x):
    """Intentionally broken example: a NumPy reduction inside tf.function.

    While tracing, x is a symbolic graph tensor. Calling f2 (e.g. f2(X))
    raises an error, presumably because NumPy cannot convert the symbolic
    expression into an array.
    """
    traced = x**2 + x*3
    return np.mean(traced)
def f3(x):
    """Mean of x**2 + 3*x using TensorFlow ops; runs eagerly when called directly.

    Uses the same formula as f1/f11 so the value and timing comparisons
    compare like with like. A previous version used x**3, which computes a
    different quantity (~0.59 instead of ~1.84 on uniform [0, 1) data).
    """
    return tf.reduce_mean(x**2 + x*3)


@tf.function
def f4(x):
    """Same computation as f3, but traced by tf.function into a cached graph."""
    return tf.reduce_mean(x**2 + x*3)


@tf.function
def f5(x):
    """tf.function wrapping the plain function f3: f3's ops get traced into f5's graph."""
    return f3(x)
# 1000x20 float32 sample, once as a NumPy array and once as a tf.Variable.
X = np.random.random((1000, 20)).astype(np.float32)
tX = tf.Variable(initial_value=X, dtype=tf.float32)
# f2(X) --> error, why?
(f1(X), f3(X), f4(X), f5(X))
(np.float32(1.8400924),
<tf.Tensor: shape=(), dtype=float32, numpy=0.5865188837051392>,
<tf.Tensor: shape=(), dtype=float32, numpy=0.5865188837051392>,
<tf.Tensor: shape=(), dtype=float32, numpy=0.5865188837051392>)
# Same comparison with the tf.Variable tX as input (f1 still returns a NumPy scalar).
f1(tX), f3(tX), f4(tX), f5(tX)
(np.float32(1.8400924),
<tf.Tensor: shape=(), dtype=float32, numpy=0.5865188837051392>,
<tf.Tensor: shape=(), dtype=float32, numpy=0.5865188837051392>,
<tf.Tensor: shape=(), dtype=float32, numpy=0.5865188837051392>)
# but f11, which has no reduction, returns the full element-wise (1000, 20) tensor
f11(tX)
<tf.Tensor: shape=(1000, 20), dtype=float32, numpy=
array([[2.5066283 , 3.566715 , 0.21405369, ..., 1.1387894 , 0.21747205,
0.5582415 ],
[0.8427817 , 2.307036 , 2.4168954 , ..., 1.2836732 , 1.1275133 ,
2.156976 ],
[3.8027775 , 0.8254186 , 2.5732207 , ..., 1.3825173 , 3.4700062 ,
2.5409565 ],
...,
[0.8170469 , 3.4651794 , 1.0635482 , ..., 0.04916687, 0.48021984,
2.854732 ],
[0.7078032 , 1.1500525 , 3.8444602 , ..., 3.0917153 , 0.8404596 ,
3.0876915 ],
[2.2794511 , 2.3670022 , 3.5228426 , ..., 0.56031686, 3.0963511 ,
0.1286898 ]], dtype=float32)>
def f1(x):
return np.mean(x**2 + x*3)
%timeit f1(X)
33 µs ± 995 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# Time f1 on the tf.Variable tX (compare with the NumPy-array timing above).
%timeit f1(tX)
1.19 ms ± 152 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
def f3(x):
return tf.reduce_mean(x**2+x**3)
%timeit f3(tX)
1.87 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
@tf.function
def f4(x):
return tf.reduce_mean(x**2+x**3)
%timeit f4(tX)
461 µs ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# Time the undecorated Python function wrapped by f4, i.e. the same code without the cached graph.
%timeit f4.python_function(tX)
2.14 ms ± 85.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
@tf.function
def f5(x):
return f3(x)
%timeit f5(tX)
451 µs ± 9.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Underlying concrete functions are actual TF graphs with no polymorphism, tied to specific input types.
tf.function maps Python polymorphism onto a set of different underlying concrete functions
@tf.function
def f(x):
    """Return x added to itself: doubles numbers, concatenates strings."""
    doubled = x + x
    return doubled


f(10), f(10.), f("a")
(<tf.Tensor: shape=(), dtype=int32, numpy=20>,
<tf.Tensor: shape=(), dtype=int32, numpy=20>,
<tf.Tensor: shape=(), dtype=string, numpy=b'aa'>)
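each concrete function is a distinct object (observe the different addresses below). Note also that above, f(10.) returned an int32 tensor: the trace created for f(10) was reused, apparently because the Python values 10. and 10 compare equal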
# Concrete function specialised for string tensors of any shape.
string_spec = tf.TensorSpec(shape=None, dtype=tf.string)
fs = f.get_concrete_function(string_spec)
fs, fs(tf.constant("aa"))
(<ConcreteFunction (x: TensorSpec(shape=<unknown>, dtype=tf.string, name=None)) -> TensorSpec(shape=<unknown>, dtype=tf.string, name=None) at 0x7F406A8BA660>,
<tf.Tensor: shape=(), dtype=string, numpy=b'aaaa'>)
# Concrete function specialised for int32 tensors of any shape.
int_spec = tf.TensorSpec(shape=None, dtype=tf.int32)
fi = f.get_concrete_function(int_spec)
fi, fi(tf.constant(1))
(<ConcreteFunction (x: TensorSpec(shape=<unknown>, dtype=tf.int32, name=None)) -> TensorSpec(shape=<unknown>, dtype=tf.int32, name=None) at 0x7F406A8BB7D0>,
<tf.Tensor: shape=(), dtype=int32, numpy=2>)
# Concrete function specialised for float32 tensors of any shape.
float_spec = tf.TensorSpec(shape=None, dtype=tf.float32)
ff = f.get_concrete_function(float_spec)
ff, ff(tf.constant(1.))
(<ConcreteFunction (x: TensorSpec(shape=<unknown>, dtype=tf.float32, name=None)) -> TensorSpec(shape=<unknown>, dtype=tf.float32, name=None) at 0x7F406B9E01D0>,
<tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
tf.function with Keras layers
import numpy as np
# Seed the global NumPy RNG so the 3x2 standard-normal sample is reproducible.
np.random.seed(seed=0)
data = np.random.standard_normal(size=(3, 2))
data
array([[ 1.76405235, 0.40015721],
[ 0.97873798, 2.2408932 ],
[ 1.86755799, -0.97727788]])
# Two fully-connected layers shared by the model_* functions below.
denser1 = tf.keras.layers.Dense(units=4, activation='relu')
denser2 = tf.keras.layers.Dense(units=1, activation='sigmoid')
observe that, in eager mode, the layers' operations are executed immediately as the code runs, so the prints show concrete tensor values
def model_1(data):
    """Eager two-layer forward pass that prints each intermediate tensor."""
    hidden = denser1(data)
    print('After the first layer:', hidden)
    prediction = denser2(hidden)
    print('After the second layer:', prediction)
    return prediction


print('Model output:\n', model_1(data))
print("--")
print('Model output:\n', model_1(data+1))
After the first layer: tf.Tensor(
[[0. 0. 0.36905685 0. ]
[1.3161621 1.5016826 0.8457896 0.20686978]
[0. 0. 0. 0. ]], shape=(3, 4), dtype=float32)
After the second layer: tf.Tensor(
[[0.4046358]
[0.5677154]
[0.5 ]], shape=(3, 1), dtype=float32)
Model output:
tf.Tensor(
[[0.4046358]
[0.5677154]
[0.5 ]], shape=(3, 1), dtype=float32)
--
After the first layer: tf.Tensor(
[[0. 0. 0.82375824 0. ]
[1.3717937 1.9063075 1.3004911 0. ]
[0. 0. 0.40059814 0. ]], shape=(3, 4), dtype=float32)
After the second layer: tf.Tensor(
[[0.29692343]
[0.5237631 ]
[0.3967103 ]], shape=(3, 1), dtype=float32)
Model output:
tf.Tensor(
[[0.29692343]
[0.5237631 ]
[0.3967103 ]], shape=(3, 1), dtype=float32)
however, with tf.function, FIRST the function is traced, producing a computational graph (the Python prints run only during this tracing and show symbolic tensors), which is what is THEN used in subsequent calls
@tf.function
def model_2(data):
    """Two-layer forward pass traced into a graph; the prints only fire while tracing."""
    hidden = denser1(data)
    print('After the first layer:', hidden)
    prediction = denser2(hidden)
    print('After the second layer:', prediction)
    return prediction


print('Model\'s output:', model_2(data))
print('--')
print('Model\'s output:', model_2(data+1))
After the first layer: Tensor("dense_3_1/Relu:0", shape=(3, 4), dtype=float32)
After the second layer: Tensor("dense_4_1/Sigmoid:0", shape=(3, 1), dtype=float32)
Model's output: tf.Tensor(
[[0.4046358]
[0.5677154]
[0.5 ]], shape=(3, 1), dtype=float32)
--
Model's output: tf.Tensor(
[[0.29692343]
[0.5237631 ]
[0.3967103 ]], shape=(3, 1), dtype=float32)
tf.function usually requires less compute time: in eager mode every operation is dispatched from Python each time the function is called, whereas with tf.function the graph is traced once and then reused
def model_1(data):
x = denser1(data)
out = denser2(x)
return out
@tf.function
def model_2(data):
x = denser1(data)
out = denser2(x)
return out
%timeit model_1(data)
2.87 ms ± 652 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# Time the tf.function version of the same forward pass.
%timeit model_2(data)
618 µs ± 10.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Including graphs in upstream functions.
Observe how we compute the gradient of a computational graph:
with model_1, the operations are executed eagerly each time the function is called;
with model_2, the graph is only generated on the first call.
def g1(data):
    """Gradient of model_1's output w.r.t. denser1's variables; the forward pass runs eagerly."""
    with tf.GradientTape() as tape:
        output = model_1(data)
    return tape.gradient(output, denser1.variables)


def g2(data):
    """Same gradient, but the forward pass goes through the tf.function model_2."""
    with tf.GradientTape() as tape:
        output = model_2(data)
    return tape.gradient(output, denser1.variables)


g2(data), g1(data)
([<tf.Tensor: shape=(2, 4), dtype=float32, numpy=
array([[ 0.1410127 , 0.08198048, -0.69603944, -0.1481947 ],
[ 0.322859 , 0.18770038, -0.67634726, -0.33930275]],
dtype=float32)>,
<tf.Tensor: shape=(4,), dtype=float32, numpy=array([ 0.14407603, 0.08376142, -0.50889206, -0.15141407], dtype=float32)>],
[<tf.Tensor: shape=(2, 4), dtype=float32, numpy=
array([[ 0.1410127 , 0.08198048, -0.69603944, -0.1481947 ],
[ 0.322859 , 0.18770038, -0.67634726, -0.33930275]],
dtype=float32)>,
<tf.Tensor: shape=(4,), dtype=float32, numpy=array([ 0.14407603, 0.08376142, -0.50889206, -0.15141407], dtype=float32)>])
# Time the gradient computed through the eager model_1.
%timeit g1(data)
5.2 ms ± 1.07 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
# Time the gradient computed through the tf.function model_2.
%timeit g2(data)
1.8 ms ± 301 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
However, even in g2 the gradient computation itself (the GradientTape part) is still executed eagerly.
If we wrap either function with tf.function, everything becomes a cached computational graph.
fg1 = tf.function(g1)
%timeit fg1(data)
650 µs ± 10.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
fg2 = tf.function(g2)
%timeit fg2(data)
711 µs ± 98.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)