How to save and load a trained TensorFlow model?

by damian_mills, in category: General Help, 3 months ago

1 answer

by gabrielle.kub, 3 months ago


To save and load a trained TensorFlow model, you can follow these steps:

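  1. Save the Model: The usual way to save a trained model is through the tf.keras API, TensorFlow's high-level API for building and training models. You can save the entire model in one step, or save individual components such as the architecture, the weights, and the optimizer's configuration.

     Saving the Entire Model: Use the model.save() method to save the entire model, including its architecture, weights, optimizer state, and the configuration of any custom layers. The custom classes themselves must be available again at load time, for example via the custom_objects argument of load_model(). For example:

     ```python
     model.save("my_model.keras")
     ```

     This creates a single file named my_model.keras in the current directory. Keras 3 (the version bundled with TensorFlow 2.16 and later) requires the .keras extension (or the legacy .h5). In older TensorFlow 2.x releases, a path without an extension, such as "my_model", produced a SavedModel directory instead.

     Saving Individual Components: If you want to save specific components, you can save the architecture as JSON, the weights in HDF5 format, and the optimizer's configuration with pickle.

     ```python
     import pickle

     # Save the architecture (layer structure and configuration) as JSON
     model_json = model.to_json()
     with open("model_architecture.json", "w") as json_file:
         json_file.write(model_json)

     # Save the weights in HDF5 format (Keras 3 requires the ".weights.h5" suffix)
     model.save_weights("model.weights.h5")

     # Optionally save the optimizer's configuration (hyperparameters such as the
     # learning rate); this does not include its internal state, e.g. Adam's moments
     with open("optimizer_state.pkl", "wb") as file:
         pickle.dump(model.optimizer.get_config(), file)
     ```

  2. Load the Model: Once the model is saved, you can load it back for further training or inference.

     Loading the Entire Model: Use the tf.keras.models.load_model() function to load the entire model. For example:

     ```python
     import tensorflow as tf

     model = tf.keras.models.load_model("my_model.keras")
     ```

     Loading Individual Components: If you saved individual components, rebuild the model from the saved architecture, then load the saved weights and, if needed, the optimizer configuration.

     ```python
     import pickle

     import tensorflow as tf

     # Rebuild the architecture from JSON (weights are freshly initialized here)
     with open("model_architecture.json", "r") as json_file:
         loaded_model_json = json_file.read()
     model = tf.keras.models.model_from_json(loaded_model_json)

     # Load the weights saved in HDF5 format
     model.load_weights("model.weights.h5")

     # Optionally recreate the optimizer from its saved configuration
     # (this assumes it was Adam; use the optimizer class you trained with)
     with open("optimizer_state.pkl", "rb") as file:
         optimizer_config = pickle.load(file)
     optimizer = tf.keras.optimizers.Adam.from_config(optimizer_config)

     # Compile before further training; also pass the loss and metrics used originally
     model.compile(optimizer=optimizer)
     ```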
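To check that a full-model save actually round-trips, you can compare predictions before and after reloading. The following is a minimal, self-contained sketch. It assumes a recent TensorFlow release whose Keras supports the .keras format, and the toy model, random data, and helper names (build_model, save_and_reload) are hypothetical and exist only for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model() -> tf.keras.Model:
    """Hypothetical small regression model used only for this illustration."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model


def save_and_reload(model: tf.keras.Model, directory: str) -> tf.keras.Model:
    """Hypothetical helper: save the whole model to a .keras file and load it back."""
    path = os.path.join(directory, "my_model.keras")
    model.save(path)
    return tf.keras.models.load_model(path)


# Random toy data (fixed seed so the data is repeatable)
rng = np.random.default_rng(0)
x = rng.normal(size=(32, 4)).astype("float32")
y = rng.normal(size=(32, 1)).astype("float32")

model = build_model()
model.fit(x, y, epochs=2, verbose=0)

# The reloaded model lives in memory, so the temporary file can be deleted afterwards
with tempfile.TemporaryDirectory() as tmp:
    restored = save_and_reload(model, tmp)

# The reloaded model should reproduce the original predictions
original_pred = model.predict(x, verbose=0)
restored_pred = restored.predict(x, verbose=0)
print(np.allclose(original_pred, restored_pred))
```

Comparing predictions is a cheap sanity check that the architecture and weights were restored. It does not exercise custom layers, which would additionally need custom_objects (or registration) when loading.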

Remember to replace the example filenames with your own. With Keras 3, full-model paths passed to save() must end in .keras (or the legacy .h5), and weight files passed to save_weights() must end in .weights.h5.
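The component-based route (JSON architecture, HDF5 weights, pickled optimizer configuration) can also be verified end to end. This is a sketch under stated assumptions: a recent TensorFlow/Keras release, an Adam optimizer, and a hypothetical toy model trained briefly on random data. Everything is written to a temporary directory, rebuilt, and then compared against the original.

```python
import os
import pickle
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical toy model and random data, for illustration only
rng = np.random.default_rng(1)
x = rng.normal(size=(16, 3)).astype("float32")
y = rng.normal(size=(16, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(4, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss="mse")
model.fit(x, y, epochs=1, verbose=0)

with tempfile.TemporaryDirectory() as tmp:
    arch_path = os.path.join(tmp, "model_architecture.json")
    weights_path = os.path.join(tmp, "model.weights.h5")
    opt_path = os.path.join(tmp, "optimizer_state.pkl")

    # Save: architecture as JSON, weights as HDF5, optimizer config via pickle
    with open(arch_path, "w") as f:
        f.write(model.to_json())
    model.save_weights(weights_path)
    with open(opt_path, "wb") as f:
        pickle.dump(model.optimizer.get_config(), f)

    # Load: rebuild the architecture, restore weights, recreate the optimizer (assumed Adam)
    with open(arch_path, "r") as f:
        rebuilt = tf.keras.models.model_from_json(f.read())
    rebuilt.load_weights(weights_path)
    with open(opt_path, "rb") as f:
        optimizer = tf.keras.optimizers.Adam.from_config(pickle.load(f))
    rebuilt.compile(optimizer=optimizer, loss="mse")

# Every weight array should be identical after the round trip
original_weights = model.get_weights()
rebuilt_weights = rebuilt.get_weights()
weights_match = len(original_weights) == len(rebuilt_weights) and all(
    np.array_equal(a, b) for a, b in zip(original_weights, rebuilt_weights)
)

# The optimizer's hyperparameters (here, the learning rate) survive via get_config()
restored_lr = rebuilt.optimizer.get_config()["learning_rate"]
print(weights_match, restored_lr)
```

Because only get_config() is pickled, the rebuilt Adam optimizer starts with fresh moment estimates. If resuming training needs to continue exactly where it stopped, saving the whole model to a .keras file, which stores the optimizer's state, is the simpler option.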