Source code for pyotc.otc_backend.optimal_transport.native_refactor

"""Yuning's other other native implementation of lp ot"""

import numpy as np
from scipy.optimize import linprog
from typing import Any


[docs] def setup_rows(Aeq: np.ndarray, nx: int, ny: int) -> None: for row in range(nx): for t in range(ny): Aeq[row, (row * ny) + t] = 1 return None
[docs] def setup_columns(Aeq: np.ndarray, nx: int, ny: int) -> None: for row in range(nx): for t in range(ny): Aeq[row, (row * ny) + t] = 1 return None
[docs] def computeot_lp(C: np.ndarray, r: np.ndarray, c: np.ndarray) -> tuple[Any, Any]: """Compute optimal transport mapping via LP. Args: C (np.ndarray): cost r (np.ndarray): _description_ c (np.ndarray): _description_ Returns: tuple[Any, Any]: _description_ """ nx = r.size ny = c.size # setup LP Aeq = np.zeros((nx + ny, nx * ny)) beq = np.concatenate((r.flatten(), c.flatten())) beq = beq.reshape(-1, 1) setup_rows(Aeq, nx, ny) setup_columns(Aeq, nx, ny) cost = C.reshape(-1, 1) # Bound bound = [[0, None]] * (nx * ny) # Solve OT LP using linprog res = linprog(cost, A_eq=Aeq, b_eq=beq, bounds=bound, method="highs") lp_sol = res.x lp_val = res.fun return lp_sol, lp_val