I’m working on creating a Jupyter notebook for training a GAN using TensorFlow, and I want to be able to restore the last checkpoint to continue training from where I left off.
I am following this official TensorFlow tutorial. However, when I call the train() function after restoring the checkpoint, the model seems to restart from the initial state instead of continuing from the last checkpoint.
here is how I save checkpoints
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, discriminator_optimizer=discriminator_optimizer, generator=generator, discriminator=discriminator)
here is how I restore it:
last_checkpoint_dir="/kaggle/working/training_checkpoints" latest=tf.train.latest_checkpoint(checkpoint_dir) print(f"latest: {latest}") checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
this is my training loop:
`# Training loop
g_losses = []
d_losses = []
def train(dataset,initial_epoch=0, epochs=5):
for epoch in range(initial_epoch, epochs, 1):
start = time.time()
n = dataset.cardinality() # The number of batches per epoch
# Iterate over all batches
batch_d_losses = []
batch_g_losses = []
for i, batch in enumerate(dataset):
# Updte parameters for this batch
g_loss, d_loss = train_step(batch)
# Store losses for batch, we will average these for the whole epoch for a more stable visualization
batch_g_losses.append(g_loss)
batch_d_losses.append(d_loss)
sys.stdout.write("r" + "Epoch %d - batch %d of %d "%(epoch+1, i+1, n) + "[D loss: "+str(d_loss)+ " | G loss: "+str(g_loss)+"]")
# Book-keeping:
# Visualize losses and one example image for the epoch
g_losses.append(np.mean(batch_g_losses))
d_losses.append(np.mean(batch_d_losses))
# display.clear_output(wait=True)
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.title('Losses')
xepoch = list(range(epoch+1))
plt.plot(d_losses, label='Discriminator')
plt.plot(g_losses, label='Generator')
plt.legend()
for image_batch in dataset:
train_step(image_batch)
# Save the model every 5 epochs
if epoch % 5 == 0: # checkpoint_save_interval
checkpoint.save(file_prefix = checkpoint_prefix)
# Save some example images and store model file
if epoch % save_interval == 0:
print('Saving epoch %d to %s'%(epoch+1, model_path))
generator.save(os.path.join(model_path, "e%0d_generator.h5"%(epoch+1)))
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
# Generate after the final epoch
display.clear_output(wait=True)
example = generate_and_save_images(generator, epoch, seed)`
I tried restoring and calling the train() function but the results of the new training seem to restart from the initial condition.