The solution is easy to implement. The idea is to (1) draw many samples of y_pred through repeated sess.run calls, as the current code already does; and (2) take the mode of the samples, which estimates the most likely discrete value.
This would be implemented as an if-clause in the current prediction source code. Namely, we take this branch when the output is a tensor whose dtype is neither floating-point nor complex:
if isinstance(output_key, tf.Tensor) and output_key.dtype not in [tf.float16, tf.float32, tf.float64, tf.bfloat16, tf.complex64, tf.complex128].
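To make the two steps concrete, here is a minimal sketch in plain Python, not the library's actual code. The names draw_sample, elementwise_mode, predict and CONTINUOUS_DTYPES are hypothetical: draw_sample stands in for one sess.run call on y_pred, dtype names are plain strings rather than TensorFlow dtypes, and the mean used in the continuous branch is only a placeholder assumption for whatever the existing code does.

```python
from collections import Counter

# Hypothetical stand-in for the floating and complex dtypes listed above.
# Real code would compare output_key.dtype against the TensorFlow dtypes.
CONTINUOUS_DTYPES = {
    "float16", "float32", "float64", "bfloat16", "complex64", "complex128",
}


def elementwise_mode(samples):
    """Return the most frequent value at each position across all draws.

    samples is a list of equal-length sequences, one sequence per draw.
    """
    # zip(*samples) regroups the draws so that each tuple holds every
    # sampled value for a single data point.
    return [Counter(column).most_common(1)[0][0] for column in zip(*samples)]


def predict(draw_sample, dtype_name, n_samples=100):
    """Draw n_samples predictions and summarise them per data point.

    draw_sample is a hypothetical zero-argument callable that plays the
    role of one sess.run call on y_pred.
    """
    # Step 1: draw many samples through repeated calls.
    samples = [draw_sample() for _ in range(n_samples)]
    # Step 2: for discrete outputs, take the mode of the samples.
    if dtype_name not in CONTINUOUS_DTYPES:
        return elementwise_mode(samples)
    # Placeholder for the continuous case (an assumption, not the source).
    return [sum(column) / len(column) for column in zip(*samples)]
```

For example, with three fixed draws over three data points:

```python
draws = iter([[0, 2, 1], [0, 1, 1], [1, 2, 1]])
print(predict(lambda: next(draws), "int32", n_samples=3))  # [0, 2, 1]
```

With a finite number of draws the mode is only an estimate of the most likely value, so n_samples trades accuracy against the cost of the repeated calls.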