Handwritten Text Recognition

Handwritten Text Recognition with TensorFlow 2 and Keras on the IAM dataset.

A Convolutional Recurrent Neural Network (CRNN) trained with CTC loss.

Author: Mohsen Dehghani

Dataset used: IAM Handwriting Database

import os
import random
import string

import cv2
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (Input, Conv2D, MaxPool2D, BatchNormalization,
                                     Lambda, Dense, Bidirectional, LSTM)
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.preprocessing.sequence import pad_sequences
# Read the word-level annotation file (one record per line)
with open('./data/words2.txt') as f:
    contents = f.readlines()

lines = [line.strip() for line in contents]

# Inspect a couple of entries
lines[2], lines[20]
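Each annotation line is assumed to follow the IAM words.txt format; the parsing further below relies on field 1 (segmentation status) and fields 8+ (transcription):

# Expected line format (IAM words.txt; stated here as an assumption):
# a01-000u-00-00 ok 154 408 768 27 51 AT A
# word_id  status  graylevel  x y w h  grammatical_tag  transcription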
# Define the portion of the dataset to use (1.0 = all lines; e.g. 0.1 for 10%)
portion = 1
num_samples = int(portion * len(lines))

# Randomly sample that portion of the lines
lines_subset = random.sample(lines, num_samples)

print(f"Original size: {len(lines)}, Subset size: {len(lines_subset)}")
lines = lines_subset

max_label_len = 0

# Character set: the index of each character is its class id;
# len(char_list) is reserved for the CTC blank / padding class.
char_list = "!\"#&'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
# Alternative: string.ascii_letters + string.digits (letters and digits only)

print(char_list, len(char_list))

def encode_to_labels(txt):
    """
    Encodes a word into a list of character indices.
    Raises ValueError for characters not in char_list,
    which the data-loading loop below catches and skips.
    """
    return [char_list.index(char) for char in txt]

def decode_to_labels(indices):
    """
    Decodes a list of indices back to the corresponding string using char_list.

    Args:
        indices (list of int): List of indices representing encoded characters.

    Returns:
        str: Decoded string.
    """
    decoded_str = ''.join([char_list[i] for i in indices])
    return decoded_str

# Example Usage
encoded_example = encode_to_labels("Hello123")
print("Encoded:", encoded_example)

decoded_example = decode_to_labels(encoded_example)
print("Decoded:", decoded_example)
RECORDS_COUNT = 10000

train_images = []
train_labels = []
train_input_length = []
train_label_length = []
train_original_text = []

valid_images = []
valid_labels = []
valid_input_length = []
valid_label_length = []
valid_original_text = []

def process_image(img):
    """
    Resizes/pads a grayscale image to shape (32, 128, 1), inverts it,
    and normalizes pixel values to [0, 1].
    """
    h, w = img.shape  # numpy shape is (rows, cols) = (height, width)

    # Scale to a fixed height of 32 while preserving the aspect ratio
    new_h = 32
    new_w = int(w * (new_h / h))
    img = cv2.resize(img, (new_w, new_h))  # cv2.resize expects (width, height)
    h, w = img.shape

    img = img.astype('float32')

    # Pad with white (255) to reach (32, 128)
    if h < 32:
        pad = np.full((32 - h, w), 255, dtype='float32')
        img = np.concatenate((img, pad))
        h, w = img.shape

    if w < 128:
        pad = np.full((h, 128 - w), 255, dtype='float32')
        img = np.concatenate((img, pad), axis=1)
        h, w = img.shape

    # Fall back to a hard resize if the image is still too large
    if w > 128 or h > 32:
        img = cv2.resize(img, (128, 32))

    # Invert: ink becomes bright, background dark
    img = cv2.subtract(255, img)

    # Add a channel dimension -> (32, 128, 1)
    img = np.expand_dims(img, axis=2)

    # Normalize to [0, 1]
    img = img / 255

    return img
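A quick shape check on a synthetic white image (an illustrative snippet, not from the original run):

# Sanity check: a synthetic 50x200 white "word" image
sample = np.full((50, 200), 255, dtype=np.uint8)
print(process_image(sample).shape)  # expected: (32, 128, 1)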
for index, line in enumerate(lines):
    splits = line.split(' ')
    status = splits[1]  # word segmentation result: 'ok' or 'err'

    if status == 'ok':
        word_id = splits[0]
        word = "".join(splits[8:])  # transcription is the last field

        # Build the image path from the word id,
        # e.g. a01-000u-00-00 -> words/a01/a01-000u/a01-000u-00-00.png
        splits_id = word_id.split('-')
        filepath = 'words/{}/{}-{}/{}.png'.format(splits_id[0],
                                                  splits_id[0],
                                                  splits_id[1],
                                                  word_id)

        # Process the image; skip entries with missing or unreadable files
        img = cv2.imread('./data/' + filepath, cv2.IMREAD_GRAYSCALE)
        try:
            img = process_image(img)
        except Exception:
            continue

        # Encode the label; skip words containing characters outside char_list
        try:
            label = encode_to_labels(word)
        except ValueError:
            continue

        # Hold out every 10th record for validation
        if index % 10 == 0:
            valid_images.append(img)
            valid_labels.append(label)
            valid_input_length.append(31)  # time steps produced by the CNN (see model below)
            valid_label_length.append(len(word))
            valid_original_text.append(word)
        else:
            train_images.append(img)
            train_labels.append(label)
            train_input_length.append(31)
            train_label_length.append(len(word))
            train_original_text.append(word)

        if len(word) > max_label_len:
            max_label_len = len(word)

    if index >= RECORDS_COUNT:
        break
# Print train/validation sample counts to verify the split
print("Subset shapes:", 
      len(train_images), 
      len(train_labels), 
      len(valid_images), 
      len(valid_labels))
# Pad labels to max_label_len; the pad value len(char_list) is the CTC blank index
train_padded_label = pad_sequences(train_labels,
                                   maxlen=max_label_len,
                                   padding='post',
                                   value=len(char_list))

valid_padded_label = pad_sequences(valid_labels,
                                   maxlen=max_label_len,
                                   padding='post',
                                   value=len(char_list))
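As an illustration, a short encoded label is padded on the right with the blank index (78 here, since len(char_list) is 78):

# Illustration: padding one encoded label to length 5
print(pad_sequences([[33, 56, 63]], maxlen=5, padding='post', value=len(char_list)))
# -> [[33 56 63 78 78]]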

train_images = np.asarray(train_images)
train_input_length = np.asarray(train_input_length)
train_label_length = np.asarray(train_label_length)

valid_images = np.asarray(valid_images)
valid_input_length = np.asarray(valid_input_length)
valid_label_length = np.asarray(valid_label_length)
train_images.shape, train_padded_label.shape, valid_images.shape, valid_padded_label.shape

Original size: 115320, Subset size: 115320
!"#&'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz 78
Encoded: [33, 56, 63, 63, 66, 14, 15, 16]
Decoded: Hello123
Subset shapes: 7579 7579 847 847

# Define input shape (height=32, width=128, 1 channel)
inputs = Input(shape=(32, 128, 1))

# Convolutional feature extractor: (32, 128, 1) -> (1, 31, 512)
conv_1 = Conv2D(64, (3,3), activation='relu', padding='same')(inputs)
pool_1 = MaxPool2D(pool_size=(2, 2), strides=2)(conv_1)   # (16, 64, 64)

conv_2 = Conv2D(128, (3,3), activation='relu', padding='same')(pool_1)
pool_2 = MaxPool2D(pool_size=(2, 2), strides=2)(conv_2)   # (8, 32, 128)

conv_3 = Conv2D(256, (3,3), activation='relu', padding='same')(pool_2)
conv_4 = Conv2D(256, (3,3), activation='relu', padding='same')(conv_3)
pool_4 = MaxPool2D(pool_size=(2, 1))(conv_4)              # (4, 32, 256)

conv_5 = Conv2D(512, (3,3), activation='relu', padding='same')(pool_4)
batch_norm_5 = BatchNormalization()(conv_5)

conv_6 = Conv2D(512, (3,3), activation='relu', padding='same')(batch_norm_5)
batch_norm_6 = BatchNormalization()(conv_6)
pool_6 = MaxPool2D(pool_size=(2, 1))(batch_norm_6)        # (2, 32, 512)

conv_7 = Conv2D(512, (2,2), activation='relu')(pool_6)    # (1, 31, 512), valid padding

# Drop the height dimension: (1, 31, 512) -> (31, 512), i.e. 31 time steps
squeezed = Lambda(lambda x: K.squeeze(x, 1))(conv_7)

# Bidirectional LSTM layers over the 31 time steps
blstm_1 = Bidirectional(LSTM(256, return_sequences=True, dropout=0.2))(squeezed)
blstm_2 = Bidirectional(LSTM(256, return_sequences=True, dropout=0.2))(blstm_1)

# Per-step softmax over len(char_list) characters plus 1 CTC blank class
outputs = Dense(len(char_list) + 1, activation='softmax')(blstm_2)

# Model for prediction (no CTC loss inputs)
act_model = Model(inputs, outputs)
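The 31 time steps produced here are what train_input_length and valid_input_length refer to; a summary call (added for verification) confirms the (None, 31, 79) output shape:

act_model.summary()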

# Additional inputs required by the CTC loss during training
the_labels = Input(name='the_labels', shape=[max_label_len], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')

# Define the custom CTC loss function:
# y_pred are the softmax outputs (batch, 31, 79), labels the padded label indices,
# input_length is 31 for every sample, label_length the true label lengths
def ctc_lambda_func(args):
    y_pred, labels, input_length, label_length = args
    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

# Add the CTC loss layer
loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([outputs, the_labels, input_length, label_length])

# Training model with CTC loss
model = Model(inputs=[inputs, the_labels, input_length, label_length], outputs=loss_out)

# Compile: the Lambda layer already outputs the loss, so the loss function simply
# passes y_pred through (the reported 'accuracy' metric is not meaningful here)
batch_size = 8
epochs = 60
optimizer_name = 'adam'
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=optimizer_name, metrics=['accuracy'])

# Callbacks setup
model_save_path = "./ocr_ctc_model"
checkpoint = ModelCheckpoint(
    filepath=model_save_path + '/best_model.h5',
    monitor='val_loss',
    verbose=1,
    save_best_only=True,
    mode='auto'
)
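TrainingPlot is used below but never defined in the notebook; a minimal sketch of such a live-plotting callback (my assumption of its intent) is:

class TrainingPlot(tf.keras.callbacks.Callback):
    """Plot training and validation loss at the end of every epoch."""
    def on_train_begin(self, logs=None):
        self.losses, self.val_losses = [], []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))
        plt.figure()
        plt.plot(self.losses, label='loss')
        plt.plot(self.val_losses, label='val_loss')
        plt.xlabel('epoch')
        plt.legend()
        plt.show()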

plot_callback = TrainingPlot()
callbacks_list = [checkpoint, plot_callback]

# Save the prediction model (without the CTC loss inputs) for easier reuse;
# it shares its layers and weights with the training model above
model_without_ctc = Model(inputs=inputs, outputs=outputs)
model_without_ctc.save(model_save_path, save_format='tf')

print(f"Model saved successfully to {model_save_path}")

# Training with callbacks: y is a dummy zero array because the CTC Lambda
# layer already computes the loss
history = model.fit(
    x=[train_images, train_padded_label, train_input_length, train_label_length],
    y=np.zeros(len(train_images)),
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(
        [valid_images, valid_padded_label, valid_input_length, valid_label_length],
        np.zeros(len(valid_images))
    ),
    verbose=1,
    callbacks=callbacks_list
)

print("Training complete. Model saved successfully.")

948/948 [==============================] - 555s 585ms/step - loss: 3.0008 - accuracy: 0.3644 - val_loss: 5.7751 - val_accuracy: 0.3542
Training complete. Model saved successfully.
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import ModelCheckpoint


# Load the saved model without CTC for inference or re-training
loaded_model = load_model(model_save_path)

print("Model loaded successfully.")

# Re-add the CTC loss Lambda layer
loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')(
    [loaded_model.output, the_labels, input_length, label_length]
)

# Build the complete model again with CTC loss
final_model = Model(inputs=[loaded_model.input, the_labels, input_length, label_length], outputs=loss_out)

print("CTC layer re-added successfully.")

# Compile the model again with the same optimizer
final_model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=optimizer_name, metrics=['accuracy'])

# Define Callbacks for Resuming Training
checkpoint = ModelCheckpoint(
    filepath="./ocr_ctc_model/best_model_resume.h5",
    monitor='val_loss',
    verbose=1,
    save_best_only=True,
    mode='auto'
)


# Callbacks list
callbacks_list = [checkpoint, plot_callback]

# Continue training from where the previous run stopped
initial_epoch = 2  # assumes two epochs were completed previously

history = final_model.fit(
    x=[train_images, train_padded_label, train_input_length, train_label_length],
    y=np.zeros(len(train_images)),
    batch_size=batch_size, 
    epochs=5,  # Continue training till epoch 5
    initial_epoch=initial_epoch,  # Start from epoch 3
    validation_data=(
        [valid_images, valid_padded_label, valid_input_length, valid_label_length],
        [np.zeros(len(valid_images))]
    ),
    verbose=1,
    callbacks=callbacks_list  # Including checkpointing and plotting
)

print("Training resumed and completed successfully.")

Sample Predictions on Training Images

# predict the output for one training image
m = 1
prediction = act_model.predict(train_images[m:m+1])

# greedy CTC decode: collapse repeated characters and drop blanks
decoded = K.ctc_decode(prediction,
                       input_length=np.ones(prediction.shape[0]) * prediction.shape[1],
                       greedy=True)[0][0]

out = K.get_value(decoded)

# show the results (-1 marks padding in the decoded sequences)
for i, x in enumerate(out):
    print("original_text =  ", train_original_text[m+i])
    print("predicted text = ", end='')
    for p in x:
        if int(p) != -1:
            print(char_list[int(p)], end='')
    plt.imshow(train_images[m+i].reshape(32, 128), cmap=plt.cm.gray)
    plt.show()
    print('\n')
1/1 [==============================] - 0s 169ms/step
original_text =   taken
predicted text = taken
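Beyond a single sample, word-level accuracy over the validation set can be computed the same way (a sketch added here, not part of the original run):

# Sketch: greedy-decoded word accuracy on the validation set
pred = act_model.predict(valid_images)
decoded = K.get_value(K.ctc_decode(pred,
                                   input_length=np.ones(pred.shape[0]) * pred.shape[1],
                                   greedy=True)[0][0])
correct = sum(
    ''.join(char_list[int(p)] for p in seq if int(p) != -1) == truth
    for seq, truth in zip(decoded, valid_original_text)
)
print(f"Validation word accuracy: {correct / len(valid_original_text):.2%}")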