learning_rate=FLAGS.learning_rate) # Compute vector of per-example # loss rather than its mean over a minibatch. loss = tf.keras.losses.CategoricalCrossentropy( from_logits=True, reduction=tf.losses.Reduction.NONE) # Compile model with Keras model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])