How to resume GAN training in TensorFlow 2.15 from a checkpoint


I’m working on creating a Jupyter notebook for training a GAN using TensorFlow, and I want to be able to restore the last checkpoint to continue training from where I left off.

I am following this official TensorFlow tutorial. However, when I call the train() function after restoring the checkpoint, the model seems to restart from the initial state instead of continuing from the last checkpoint.

Here is how I set up the checkpoint:

```python
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
```
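The checkpoint above stores the models and optimizers but not the epoch number, so nothing tells `train()` where to pick up. A minimal sketch, assuming the same four objects plus a hypothetical `epoch` counter variable and `tf.train.CheckpointManager` (the helper names `build_checkpoint_manager` and `save_after_epoch` are illustrative, not part of the tutorial):

```python
import tensorflow as tf

def build_checkpoint_manager(generator, discriminator,
                             generator_optimizer, discriminator_optimizer,
                             checkpoint_dir, max_to_keep=3):
    # Hypothetical extra state: the number of completed epochs. Any tf.Variable
    # passed to tf.train.Checkpoint is written and restored along with the models.
    epoch = tf.Variable(0, dtype=tf.int64, trainable=False)
    checkpoint = tf.train.Checkpoint(
        epoch=epoch,
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer,
        generator=generator,
        discriminator=discriminator,
    )
    # The manager writes numbered checkpoints ("ckpt-1", "ckpt-2", ...) into
    # checkpoint_dir and deletes all but the newest max_to_keep.
    manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=max_to_keep)
    return checkpoint, manager

def save_after_epoch(checkpoint, manager, completed_epochs):
    # Record the epoch count before writing, so a restore resumes after this epoch.
    checkpoint.epoch.assign(completed_epochs)
    return manager.save()
```

Under these assumptions, the training loop would call `save_after_epoch(checkpoint, manager, epoch + 1)` where it currently calls `checkpoint.save(...)`, and after `checkpoint.restore(manager.latest_checkpoint)` the value `int(checkpoint.epoch)` could be passed as `initial_epoch`.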

Here is how I restore it:

```python
checkpoint_dir = "/kaggle/working/training_checkpoints"
latest = tf.train.latest_checkpoint(checkpoint_dir)
print(f"latest: {latest}")  # None means no checkpoint was found in checkpoint_dir
checkpoint.restore(latest)
```
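`tf.train.latest_checkpoint` returns `None` instead of raising when the directory holds no checkpoint, and `checkpoint.restore(None)` then silently restores nothing. A hedged sketch of a stricter restore (the helper name `restore_latest` is hypothetical) that fails loudly in both situations:

```python
import tensorflow as tf

def restore_latest(checkpoint, checkpoint_dir):
    # latest_checkpoint returns None (it does not raise) when checkpoint_dir has
    # no checkpoint index file, e.g. because the path is wrong.
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    if latest is None:
        raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir!r}")
    status = checkpoint.restore(latest)
    # Raises AssertionError if an already-built Python object (e.g. a model with
    # created weights) has no matching value in the checkpoint.
    status.assert_existing_objects_matched()
    return latest
```

If this raises on the Kaggle path, the problem is where the checkpoints were written rather than how `train()` is called.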

This is my training loop:

```python
import os
import sys
import time

import matplotlib.pyplot as plt

# Training loop
# generator, train_step, checkpoint, checkpoint_prefix, save_interval, model_path,
# seed and generate_and_save_images are defined in earlier notebook cells.
g_losses = []
d_losses = []

def train(dataset, initial_epoch=0, epochs=5):
    for epoch in range(initial_epoch, epochs):
        start = time.time()

        n = int(dataset.cardinality())  # The number of batches per epoch
        # Iterate over all batches
        batch_d_losses = []
        batch_g_losses = []
        for i, batch in enumerate(dataset):
            # Update parameters for this batch
            g_loss, d_loss = train_step(batch)
            # Store losses for the batch; they are averaged over the epoch for a more stable plot
            batch_g_losses.append(float(g_loss))
            batch_d_losses.append(float(d_loss))
            sys.stdout.write("\rEpoch %d - batch %d of %d [D loss: %.4f | G loss: %.4f]"
                             % (epoch + 1, i + 1, n, float(d_loss), float(g_loss)))

        # Book-keeping: average the batch losses for this epoch
        d_losses.append(sum(batch_d_losses) / len(batch_d_losses))
        g_losses.append(sum(batch_g_losses) / len(batch_g_losses))

        # Visualize losses and one example image for the epoch
        # display.clear_output(wait=True)
        plt.plot(d_losses, label='Discriminator')
        plt.plot(g_losses, label='Generator')
        plt.legend()
        plt.show()

        # Save a checkpoint every 5 epochs
        if epoch % 5 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        # Save some example images and store the model file
        if epoch % save_interval == 0:
            print('Saving epoch %d to %s' % (epoch + 1, model_path))
            generator.save(os.path.join(model_path, "e%0d_generator.h5" % (epoch + 1)))

        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

    # Generate after the final epoch
    example = generate_and_save_images(generator, epoch, seed)
```
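Because `initial_epoch` defaults to `0`, calling `train(dataset)` after a restore starts the epoch count from zero again. For checkpoints that were already written without an epoch counter, one stopgap is to infer the epoch from the checkpoint number: `Checkpoint.save` names files `<prefix>-<save_counter>`, and under this loop's `epoch % 5 == 0` schedule checkpoint number k was written after epoch index `5 * (k - 1)`. A hedged sketch (the function name is hypothetical), assuming every resumed run continued from the epoch this function returns; storing the epoch inside the checkpoint itself is more robust:

```python
import os
import re

def resume_epoch_from_path(latest_path, save_every=5):
    """Guess the next epoch index from a checkpoint path such as '.../ckpt-3'."""
    if latest_path is None:
        return 0  # no checkpoint yet: start from scratch
    match = re.search(r"-(\d+)$", os.path.basename(latest_path))
    if match is None:
        raise ValueError(f"Unexpected checkpoint name: {latest_path!r}")
    k = int(match.group(1))
    # Saves happened at epochs 0, save_every, 2 * save_every, ..., so checkpoint k
    # follows epoch save_every * (k - 1) and training resumes at the next index.
    return save_every * (k - 1) + 1
```

For example, `train(dataset, initial_epoch=resume_epoch_from_path(latest), epochs=50)` would continue at epoch index 11 when the latest file is `ckpt-3`.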

I tried restoring and calling the train() function but the results of the new training seem to restart from the initial condition.
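To separate "the checkpoint does not restore" from "`train()` is called in a way that starts over", a toy end-to-end check can help. This is a sketch under stated assumptions: a one-layer model and SGD stand in for the GAN's networks and optimizers, and `make_toy_state` and `train_with_resume` are hypothetical names. It simulates two notebook sessions that share a checkpoint directory:

```python
import numpy as np
import tensorflow as tf

def make_toy_state(checkpoint_dir):
    # Toy stand-ins (assumption) for the generator/discriminator and their optimizers.
    model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    epoch = tf.Variable(0, dtype=tf.int64, trainable=False)
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer, epoch=epoch)
    manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=2)
    return ckpt, manager

def train_with_resume(checkpoint_dir, epochs, x, y):
    # Build fresh objects, as a restarted notebook would, then restore into them.
    ckpt, manager = make_toy_state(checkpoint_dir)
    # Restores model, optimizer and epoch if a checkpoint exists; a no-op otherwise.
    ckpt.restore(manager.latest_checkpoint)
    start_weights = ckpt.model.get_weights()
    ran = []
    for epoch in range(int(ckpt.epoch), epochs):
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(tf.square(ckpt.model(x) - y))
        grads = tape.gradient(loss, ckpt.model.trainable_variables)
        ckpt.optimizer.apply_gradients(zip(grads, ckpt.model.trainable_variables))
        ckpt.epoch.assign(epoch + 1)  # number of completed epochs
        manager.save()
        ran.append(epoch)
    return ran, start_weights, ckpt.model.get_weights()
```

Calling `train_with_resume(d, 3, x, y)` and then `train_with_resume(d, 5, x, y)` on the same directory `d` should run epochs 0 to 2 and then only epochs 3 and 4, with the second call starting from the first call's final weights. If the equivalent check passes for the real models, one thing worth ruling out is a notebook cell that rebuilds `generator` or `discriminator` after `checkpoint.restore()` has run, since the restored values live in the old objects.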
