On this page
article
Training Neural Networks with TensorFlow
Train, evaluate, and tune deep learning models in TensorFlow — callbacks, data pipelines, transfer learning, and saving models.
Building a model is step one. Training it effectively — with proper data pipelines, callbacks, and evaluation — is what separates working models from production-ready ones.
Data Pipeline with tf.data
import tensorflow as tf
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()


def preprocess(image, label):
    """Cast an image to float32 and divide by 255; return the label unchanged.

    Dividing by 255 maps 0-255 pixel values into [0, 1].
    """
    image = tf.cast(image, tf.float32) / 255.0
    return image, label


# Training data: shuffle raw examples, preprocess them in parallel, batch,
# and prefetch so input preparation overlaps with the training step.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10000)
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)
)

# Evaluation data: same preprocessing, deterministic order (no shuffle).
test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)
)
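`prefetch` overlaps preparing the next batches with training on the current one, so the GPU is not left idle waiting for input; `tf.data.AUTOTUNE` lets TensorFlow tune the buffer size at runtime.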
CNN for Image Classification
# Small CNN for 32x32 RGB images: two conv/pool stages, then a dense
# classifier with dropout and a 10-way softmax output.
# The input shape is declared with an Input object; passing `input_shape=`
# to a layer is deprecated in Keras 3 and emits a warning.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(32, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10, activation="softmax"),
])
# Sparse categorical cross-entropy takes integer class ids directly (no
# one-hot encoding); the softmax layer already outputs probabilities.
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
Training with Callbacks
# Stop once val_loss has not improved for 5 epochs, rolling back to the
# best weights seen so far.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True
)
# Halve the learning rate after 3 epochs without val_loss improvement.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=0.5, patience=3
)
# Overwrite the checkpoint file only when the monitored metric improves.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "best_model.keras", save_best_only=True
)
# Write training logs for TensorBoard.
tensorboard = tf.keras.callbacks.TensorBoard(log_dir="./logs")

callbacks = [early_stopping, reduce_lr, checkpoint, tensorboard]

# NOTE: test_ds doubles as validation data here, so early stopping and
# checkpoint selection see the test set; hold out a separate validation
# split if you need an unbiased final test score.
history = model.fit(
    train_ds,
    validation_data=test_ds,
    epochs=50,
    callbacks=callbacks,
)
Visualizing Training History
import matplotlib.pyplot as plt
# Side-by-side curves: loss on the left, accuracy on the right, each with
# the training and validation series recorded by model.fit.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
panels = (
    ("loss", "val_loss", "Loss"),
    ("accuracy", "val_accuracy", "Accuracy"),
)
for axis, (train_key, val_key, title) in zip(axes, panels):
    axis.plot(history.history[train_key], label="train")
    axis.plot(history.history[val_key], label="val")
    axis.set_title(title)
    axis.legend()
plt.show()
Data Augmentation
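Reduce overfitting with on-the-fly augmentation. The random layers only transform images in training mode (for example during `model.fit`); at inference they pass inputs through unchanged: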
# Random transforms applied on the fly, only in training mode.
# RandomRotation(0.1) rotates by up to +/-0.1 * 2*pi; RandomZoom(0.1) zooms by up to +/-10%.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.1),
])

# The CNN from above, with augmentation as its first stage. The input shape
# is declared with an Input object because Keras only honours `input_shape`
# on the first layer of a Sequential model.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    data_augmentation,
    tf.keras.layers.Conv2D(32, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
Transfer Learning
# Pretrained EfficientNetB0 feature extractor (ImageNet weights, no
# classifier head), frozen so that only the new head trains at first.
base = tf.keras.applications.EfficientNetB0(
    include_top=False, weights="imagenet", input_shape=(224, 224, 3)
)
base.trainable = False

# Accept the batches train_ds/test_ds produce: 32x32 RGB scaled to [0, 1].
inputs = tf.keras.Input(shape=(32, 32, 3))
# Upsample to the 224x224 resolution the base was built for.
x = tf.keras.layers.Resizing(224, 224)(inputs)
# Keras EfficientNet models rescale internally and expect raw [0, 255]
# pixels, so undo the /255 from preprocess(). For EfficientNet,
# preprocess_input is a pass-through, kept for parity with other models.
x = tf.keras.layers.Rescaling(255.0)(x)
x = tf.keras.applications.efficientnet.preprocess_input(x)
# training=False keeps the base's BatchNormalization layers in inference
# mode, including later when the base is unfrozen for fine-tuning.
x = base(x, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(10, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# Train only the new head before any fine-tuning. Unfreezing pretrained
# layers next to a randomly initialised head lets large early gradients
# damage the pretrained features.
head_history = model.fit(train_ds, epochs=5, validation_data=test_ds)
Fine-Tuning
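Once the new head has been trained with the base frozen, unfreeze the top layers of the base and continue training with a much lower learning rate: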
# Unfreeze the whole base, then re-freeze everything except its last 20 layers.
base.trainable = True
for frozen_layer in base.layers[:-20]:
    frozen_layer.trainable = False

# Recompile so the new trainable flags take effect. The learning rate is
# far below Adam's default, so the pretrained weights are only nudged.
fine_tune_optimizer = tf.keras.optimizers.Adam(1e-5)
model.compile(
    optimizer=fine_tune_optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model.fit(train_ds, validation_data=test_ds, epochs=10)
Export for Production
# Keras native format, reloadable with tf.keras.models.load_model.
model.save("final_model.keras")

# SavedModel directory, the format consumed by TensorFlow Serving.
saved_model_dir = "saved_model"
model.export(saved_model_dir)

# Convert that SavedModel to TensorFlow Lite for mobile deployment and
# write the resulting flatbuffer bytes to disk.
tflite_bytes = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir).convert()
with open("model.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_bytes)
Effective training workflows — data pipelines, callbacks, augmentation, and transfer learning — are essential skills for any ML engineer.