Custom random variable with two parts

I am starting to see the real problem now: the distributions sum the log prob of each data point.

In my model the dimensions are [N, T, 2]. Calculating the probability requires matrix multiplication along the time dimension, so it cannot be put in log form at the element level. The log can only be applied after all the elements of a sequence have gone through the matrix multiplication and been multiplied together. It's OK to take the log prob for each batch sample (N axis), but not for each time element (T axis) within a single batch sample.
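To make the constraint concrete, here is a minimal NumPy sketch of the kind of computation I mean. It is not my actual model: it assumes each time step contributes a hypothetical 2x2 matrix with positive entries, and `mats`, `init`, `final` and `sequence_log_prob` are names made up for illustration. The chain is multiplied through along T, rescaled at each step to avoid underflow, and one log prob comes out per batch sample.

```python
import numpy as np

def sequence_log_prob(mats, init, final):
    """Return one log-probability per batch sample, shape [N].

    mats:  [N, T, 2, 2] hypothetical per-time-step matrices (positive entries)
    init:  [2] hypothetical starting row vector
    final: [2] hypothetical closing column vector
    """
    n, t = mats.shape[:2]
    vec = np.broadcast_to(init, (n, 2)).astype(float)  # [N, 2]
    log_prob = np.zeros(n)
    for step in range(t):
        # Row vector times this step's matrix, for every sample at once.
        vec = np.einsum("ni,nij->nj", vec, mats[:, step])
        # Rescale to avoid underflow; keep the log of the scale factor.
        scale = vec.sum(axis=1)
        vec = vec / scale[:, None]
        log_prob += np.log(scale)
    # Only after the whole chain is multiplied through is the log complete.
    return log_prob + np.log(vec @ final)
```

The result has shape [N], so there is no meaningful per-time-step log prob to hand back; the accumulated `np.log(scale)` terms are only bookkeeping for numerical stability.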

Can you think of any way to take care of this situation?