diff --git a/tensorflow_gnn/runner/orchestration.py b/tensorflow_gnn/runner/orchestration.py
index eba14077..1ddab9e2 100644
--- a/tensorflow_gnn/runner/orchestration.py
+++ b/tensorflow_gnn/runner/orchestration.py
@@ -315,13 +315,20 @@ def apply_fn(ds,
                                               drop_remainder,
                                               global_batch_size)
 
-  def adapted_model_fn():
+  def adapted_model_fn(options = None):
     if isinstance(preprocess_model.output, collections.abc.Sequence):
       x, *_ = preprocess_model.output
     else:
       x = preprocess_model.output
     m = task.adapt(model_fn(x.spec))
     optimizer = optimizer_fn()
+    if options and options.policy:
+      # Cast logits to `tf.keras.backend.floatx()` for mixed_precision.
+      # For more details, see:
+      # https://www.tensorflow.org/guide/mixed_precision#building_the_model.
+      floatx = tf.keras.backend.floatx()
+      outputs = [tf.cast(o, dtype=floatx) for o in m.outputs]
+      m = tf.keras.Model(m.inputs, outputs)
     if train_padding is None:
       m.compile(optimizer, loss=task.losses(), metrics=task.metrics())
     else:
diff --git a/tensorflow_gnn/runner/trainers/keras_fit.py b/tensorflow_gnn/runner/trainers/keras_fit.py
index 16c17133..c46ca001 100644
--- a/tensorflow_gnn/runner/trainers/keras_fit.py
+++ b/tensorflow_gnn/runner/trainers/keras_fit.py
@@ -269,14 +269,6 @@ def per_replica_ds_fn(input_context, *, delegate, repeat):
     with self._strategy.scope():
       model = model_fn()
 
-      if self._options and self._options.policy:
-        # Cast logits to `tf.keras.backend.floatx()` for mixed_precision.
-        # For more details, see:
-        # https://www.tensorflow.org/guide/mixed_precision#building_the_model.
-        floatx = tf.keras.backend.floatx()
-        outputs = [tf.cast(o, dtype=floatx) for o in model.outputs]
-        model = tf.keras.Model(model.inputs, outputs)
-
       model.fit(
           train_ds,
           epochs=epochs,