The solution is easy to implement. The idea is to (1) draw many samples of y_pred through repeated sess.run calls, as the current code already does; and (2) take the mode of the samples, which recovers the most likely discrete value.
This would be implemented as a case of the if-clause in the current prediction source code. Namely, we run it if isinstance(output_key, tf.Tensor) and output_key.dtype not in [tf.float16, tf.float32, tf.float64, tf.bfloat16, tf.complex64, tf.complex128].
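A minimal sketch of the two steps, written in plain Python so it runs without TensorFlow. The names is_discrete, mc_mode and sample_fn are hypothetical: sample_fn stands in for one sess.run call on output_key, and dtypes are compared by name rather than as tf.DType objects, which is an assumption made only to keep the example self-contained.

```python
from collections import Counter
from typing import Callable, List, Sequence

# Floating-point dtypes from the condition above, as plain strings
# (assumption: the real check compares tf.DType objects, not names).
FLOAT_DTYPE_NAMES = {"float16", "float32", "float64", "bfloat16", "complex64", "complex128"}


def is_discrete(dtype_name: str) -> bool:
    """True when the output dtype is not one of the floating-point types."""
    return dtype_name not in FLOAT_DTYPE_NAMES


def mc_mode(sample_fn: Callable[[], Sequence], n_samples: int) -> List:
    """Draw n_samples predictions and return the most common value per data point."""
    # Step 1: repeated draws; sample_fn is a hypothetical stand-in for sess.run(output_key, ...).
    draws = [list(sample_fn()) for _ in range(n_samples)]
    # Step 2: group the draws by data point and take the mode of each group.
    return [Counter(column).most_common(1)[0][0] for column in zip(*draws)]


if __name__ == "__main__":
    import random

    rng = random.Random(0)
    # Hypothetical sampler: each of 3 data points is 1 with probability 0.8.
    bernoulli = lambda: [int(rng.random() < 0.8) for _ in range(3)]
    if is_discrete("int32"):
        print(mc_mode(bernoulli, 1000))
```

Taking the mode rather than the mean keeps the prediction inside the support of the discrete variable; averaging Bernoulli draws, for instance, would give a fraction such as 0.8 instead of a class label.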
Hi @dustin,
I tried your suggestion, but I don't understand where I should add the aforementioned code. Could you please let me know where, and at which line in the evaluate module, I should add it?
Thanks for your time.
The clause forms predictions y_pred by Monte Carlo estimation of the posterior predictive mean, and how it does so depends on output_key’s data type. Therefore your proposed extension should be added as another case of the if-clause, extending the existing cases rather than replacing them. Does that make sense?
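A rough sketch of that structure, again in plain Python with hypothetical names (form_predictions, sample_fn, FLOAT_DTYPES) rather than the actual evaluate code: the existing floating-point case keeps its Monte Carlo mean, and the discrete case sits beside it as a separate branch instead of replacing it.

```python
from collections import Counter
from statistics import mean
from typing import Callable, List, Sequence

# Assumed string names for the dtypes that the existing branch averages over.
FLOAT_DTYPES = {"float16", "float32", "float64", "bfloat16", "complex64", "complex128"}


def form_predictions(sample_fn: Callable[[], Sequence], n_samples: int, dtype_name: str) -> List:
    # Monte Carlo: draw n_samples predictions (stand-in for repeated sess.run calls).
    draws = [list(sample_fn()) for _ in range(n_samples)]
    per_point = list(zip(*draws))
    if dtype_name in FLOAT_DTYPES:
        # Existing case: estimate the posterior predictive mean.
        return [mean(values) for values in per_point]
    else:
        # Proposed extension: discrete outputs take the mode instead.
        return [Counter(values).most_common(1)[0][0] for values in per_point]
```

Because the new branch is only reached for non-floating-point dtypes, continuous outputs behave exactly as before.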