# Source code for pyotc.otc_backend.policy_iteration.sparse.exact_tci

"""
Original Transition Coupling Improvements (TCI) method from:
https://www.jmlr.org/papers/volume23/21-0519/21-0519.pdf
"""

import numpy as np
import scipy.sparse as sp
from pyotc.otc_backend.optimal_transport.pot import computeot_pot


def setup_ot(f, Px, Py, R):
    """
    Rebuild every row of the transition coupling matrix R against the cost f.

    For each state pair (x, y), the row r = R((x, y), ·) is replaced by a
    coupling of Px(x, ·) and Py(y, ·) that minimizes r·f. When either marginal
    is degenerate (contains an entry exactly equal to 1) the product coupling
    np.outer(Px(x, ·), Py(y, ·)) is used directly; otherwise the row is obtained
    from computeot_pot (the project's POT-based optimal transport solver).

    Each row is *replaced*, not merged: entries that were stored in R before the
    call and are not part of the new solution are removed. This matters when R
    is passed in already populated (e.g. a copy of a previous coupling), because
    leaving old entries behind would produce rows that are no longer couplings.

    Args:
        f (np.ndarray): Cost function of shape (dx*dy,); reshaped to (dx, dy).
        Px (np.ndarray): Transition matrix of the source Markov chain of shape (dx, dx).
        Py (np.ndarray): Transition matrix of the target Markov chain of shape (dy, dy).
        R (scipy.sparse matrix or np.ndarray): Transition coupling matrix of
            shape (dx*dy, dx*dy). It is modified in place. A LIL matrix is the
            efficient choice; other containers are handled through generic
            item assignment.

    Returns:
        R: The same object that was passed in, with every row rewritten.
    """
    dx, dy = Px.shape[0], Py.shape[0]
    f_mat = np.reshape(f, (dx, dy))

    # LIL stores one (columns, values) list pair per row, so a row can be
    # swapped out directly instead of going through element-wise __setitem__.
    is_lil = sp.issparse(R) and R.format == "lil"

    for x_row in range(dx):
        dist_x = Px[x_row, :]
        # Degeneracy of the x-marginal does not depend on y; test it once.
        x_degenerate = np.any(dist_x == 1)

        for y_row in range(dy):
            dist_y = Py[y_row, :]

            # With a degenerate marginal the product coupling is the only
            # coupling, so no optimal transport solve is needed.
            if x_degenerate or np.any(dist_y == 1):
                sol = np.outer(dist_x, dist_y)
            else:
                sol, _ = computeot_pot(f_mat, dist_x, dist_y)

            idx = dy * x_row + y_row
            sol_flat = np.asarray(sol).ravel()
            support = np.nonzero(sol_flat)[0]  # ascending column indices

            if is_lil:
                # Overwrite the row's storage; this also drops stale entries.
                R.rows[idx] = support.tolist()
                R.data[idx] = sol_flat[support].astype(R.dtype).tolist()
            else:
                # Clear the row first so entries from a previous coupling
                # cannot survive next to the new solution.
                R[idx, :] = 0
                if support.size:
                    R[idx, support] = sol_flat[support]
    return R
def exact_tci(g, h, R0, Px, Py):
    """
    Performs the Transition Coupling Improvement (TCI) step in the OTC algorithm.

    The current transition coupling R0 is first tested for improvement against
    the gain vector g (skipped when g is numerically constant) and then against
    the bias vector h. Both vectors come from the Transition Coupling
    Evaluation (TCE).

    Every candidate coupling is built in its own empty LIL matrix. The candidate
    must not be seeded with a copy of R0: setup_ot may only write the non-zero
    entries of each new row, so any pre-existing R0 entries would remain in the
    candidate and corrupt both the candidate and the comparison against R0.

    Args:
        g (np.ndarray): Gain vector from TCE of shape (dx*dy,).
        h (np.ndarray): Bias vector from TCE of shape (dx*dy,).
        R0 (scipy.sparse.csr_matrix or lil_matrix): Current transition coupling matrix of shape (dx*dy, dx*dy).
        Px (np.ndarray): Transition matrix of the source Markov chain of shape (dx, dx).
        Py (np.ndarray): Transition matrix of the target Markov chain of shape (dy, dy).

    Returns:
        R: Transition coupling matrix of shape (dx*dy, dx*dy). This is a newly
        built scipy.sparse.lil_matrix when an improvement was found, otherwise
        a copy of R0 (in whatever format R0 has).
    """
    dx, dy = Px.shape[0], Py.shape[0]
    n_pairs = dx * dy

    # Check if g is constant (up to a tolerance).
    g_const = np.max(g) - np.min(g) <= 1e-3

    # If g is not constant, improve transition coupling against g.
    if not g_const:
        R = setup_ot(g, Px, Py, sp.lil_matrix((n_pairs, n_pairs)))
        g_unchanged = np.max(np.abs(R0.dot(g) - R.dot(g))) <= 1e-7
        if not g_unchanged:
            return R

    # Try to improve with respect to h, starting again from an empty matrix.
    R = setup_ot(h, Px, Py, sp.lil_matrix((n_pairs, n_pairs)))
    if np.max(np.abs(R0.dot(h) - R.dot(h))) <= 1e-4:
        # No measurable improvement: keep the current coupling.
        R = R0.copy()
    return R