How to resume GAN training from a checkpoint in TensorFlow 2.15

I’m writing a Jupyter notebook that trains a GAN with TensorFlow. I want to restore the last checkpoint so training continues from where I left off.

I am following the official TensorFlow tutorial. However, when I call `train()` after restoring the checkpoint, the model seems to start over from its initial state instead of continuing from the last checkpoint.

Here is how I set up checkpointing:
```python
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
```
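The `Checkpoint` above tracks the models and optimizers, but nothing in it records how many epochs have already run. The sketch below saves through `tf.train.CheckpointManager` and stores an epoch count alongside the weights. It makes three assumptions:

- tiny hypothetical stand-in models replace the real generator and discriminator;
- a temporary directory replaces the Kaggle path;
- an extra `epoch` variable, which is not part of the tutorial, holds the epoch count.

```python
import tempfile

import tensorflow as tf

# Hypothetical stand-ins for the tutorial's generator/discriminator and optimizers
generator = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(8)])
discriminator = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(1)])
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

# Assumption: an extra tracked variable holding the number of completed epochs
epoch_var = tf.Variable(0, dtype=tf.int64, trainable=False)

checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
    epoch=epoch_var,
)

# Temporary directory standing in for /kaggle/working/training_checkpoints
checkpoint_dir = tempfile.mkdtemp()
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)

for epoch in range(5):
    # ... the real train_step(batch) calls would run here ...
    epoch_var.assign(epoch + 1)  # record that this epoch has finished
    save_path = manager.save()   # writes ckpt-1, ckpt-2, ...; keeps only the newest 3
    print("Saved", save_path)
```

`max_to_keep=3` is an arbitrary choice; the manager deletes older checkpoint files automatically.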

Here is how I restore it:
```python
checkpoint_dir = "/kaggle/working/training_checkpoints"
latest = tf.train.latest_checkpoint(checkpoint_dir)
print(f"latest: {latest}")
checkpoint.restore(latest)
```
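`tf.train.latest_checkpoint` returns `None` when the directory it is given holds no checkpoint. `checkpoint.restore(None)` does not raise an error, so restoring from the wrong directory silently leaves freshly initialized weights.

The sketch below makes a failed restore fail loudly. It uses hypothetical stand-in models, a temporary directory, and an extra `epoch` variable (an assumption, not part of the tutorial). It simulates a save in one session and a restore into fresh objects in the next.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_checkpoint():
    """Hypothetical helper: fresh stand-in models, optimizers and epoch counter."""
    generator = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(8)])
    discriminator = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(1)])
    epoch_var = tf.Variable(0, dtype=tf.int64, trainable=False)
    checkpoint = tf.train.Checkpoint(
        generator_optimizer=tf.keras.optimizers.Adam(1e-4),
        discriminator_optimizer=tf.keras.optimizers.Adam(1e-4),
        generator=generator,
        discriminator=discriminator,
        epoch=epoch_var,
    )
    return checkpoint, generator, epoch_var


checkpoint_dir = tempfile.mkdtemp()  # stands in for /kaggle/working/training_checkpoints

# First session: pretend 3 epochs finished, then save as the tutorial does.
ckpt_a, gen_a, epoch_a = build_checkpoint()
epoch_a.assign(3)
ckpt_a.save(file_prefix=os.path.join(checkpoint_dir, "ckpt"))

# Second session (e.g. a restarted notebook): new objects with new random weights.
ckpt_b, gen_b, epoch_b = build_checkpoint()
latest = tf.train.latest_checkpoint(checkpoint_dir)  # must be the directory used for saving
print(f"latest: {latest}")
if latest is None:
    raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir}")
status = ckpt_b.restore(latest)
status.assert_existing_objects_matched()  # raises if existing objects were not restored

initial_epoch = int(epoch_b)  # number of epochs completed before the restart
print("Resuming after epoch", initial_epoch)
```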

This is my training loop:
```python
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np

# Training loop
# train_step, checkpoint, checkpoint_prefix, save_interval, model_path,
# generator, generate_and_save_images and seed are defined earlier in the notebook.
g_losses = []
d_losses = []

def train(dataset, initial_epoch=0, epochs=5):
    for epoch in range(initial_epoch, epochs):
        start = time.time()

        n = int(dataset.cardinality())  # The number of batches per epoch
        # Iterate over all batches
        batch_d_losses = []
        batch_g_losses = []
        for i, batch in enumerate(dataset):
            # Update parameters for this batch
            g_loss, d_loss = train_step(batch)
            # Store losses for the batch; we average them over the epoch for a more stable visualization
            batch_g_losses.append(float(g_loss))
            batch_d_losses.append(float(d_loss))
            sys.stdout.write("\rEpoch %d - batch %d of %d [D loss: %.4f | G loss: %.4f]"
                             % (epoch + 1, i + 1, n, float(d_loss), float(g_loss)))

        # Book-keeping: average the batch losses for this epoch
        d_losses.append(np.mean(batch_d_losses))
        g_losses.append(np.mean(batch_g_losses))

        # Visualize the losses so far
        # display.clear_output(wait=True)
        plt.plot(d_losses, label='Discriminator')
        plt.plot(g_losses, label='Generator')
        plt.legend()
        plt.show()

        # Save a checkpoint every 5 epochs
        if epoch % 5 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        # Save some example images and store the model file
        if epoch % save_interval == 0:
            print('Saving epoch %d to %s' % (epoch + 1, model_path))
            generator.save(os.path.join(model_path, "e%0d_generator.h5" % (epoch + 1)))

        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

    # Generate after the final epoch
    example = generate_and_save_images(generator, epochs, seed)
```
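Restoring weights does not move this loop's epoch counter. `range(initial_epoch, epochs)` starts from whatever `initial_epoch` is passed, which defaults to 0. `epochs` is the epoch to stop at, not a number of extra epochs.

The sketch below shows a full resume. It makes two assumptions:

- a hypothetical one-layer toy model stands in for the GAN;
- an extra `epoch` variable, which is not part of the tutorial, is stored in the checkpoint.

The second session restores the checkpoint and passes the saved epoch count as `initial_epoch`.

```python
import tempfile

import tensorflow as tf

checkpoint_dir = tempfile.mkdtemp()  # stands in for /kaggle/working/training_checkpoints


def make_state():
    """Hypothetical toy: one Dense layer and an optimizer instead of the GAN."""
    model = tf.keras.Sequential([tf.keras.Input(shape=(1,)), tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.SGD(0.1)
    epoch_var = tf.Variable(0, dtype=tf.int64, trainable=False)  # completed epochs
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer, epoch=epoch_var)
    manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=2)
    return model, optimizer, epoch_var, checkpoint, manager


def train(state, dataset, initial_epoch=0, epochs=5):
    model, optimizer, epoch_var, _, manager = state
    ran = []
    for epoch in range(initial_epoch, epochs):  # `epochs` is the last epoch, not a count
        for x, y in dataset:
            with tf.GradientTape() as tape:
                loss = tf.reduce_mean(tf.square(model(x) - y))
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
        epoch_var.assign(epoch + 1)  # store progress alongside the weights
        manager.save()
        ran.append(epoch + 1)  # 1-based epoch numbers, like the printed output above
    return ran


x = tf.random.normal((32, 1))
dataset = tf.data.Dataset.from_tensor_slices((x, 3.0 * x)).batch(8)

# First session: stop after 2 of 5 epochs (e.g. the notebook session ended).
first = train(make_state(), dataset, initial_epoch=0, epochs=2)

# Second session: fresh objects, restore, then continue from the saved epoch.
state = make_state()
_, _, epoch_var, checkpoint, manager = state
checkpoint.restore(manager.latest_checkpoint)
second = train(state, dataset, initial_epoch=int(epoch_var), epochs=5)
print(first, second)
```

The first session runs epochs 1 and 2. The resumed session runs only epochs 3 to 5, starting from the restored weights.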

I tried restoring the checkpoint and then calling `train()`, but the new run seems to start again from the initial conditions.

