Source code for teaspoon.TDA.Distance

"""
This module provides algorithms to compute pairwise distances between persistence diagrams.  The bottleneck distance and wasserstein distance are available.

"""

import numpy as np
from typing import AnyStr
import ot
from sklearn.metrics import pairwise_distances
import persim

"""
.. module: Distance
"""


[docs]def wassersteinDist( pts0: np.ndarray, pts1: np.ndarray, p: int = 2, q: int = 2, y_axis: AnyStr = "death", ) -> float: """ Compute the Persistant p-Wasserstein distance between the diagrams pts0, pts1 using optimal transport. Parameters ---------- pts0: array of shape (n_top_features, 2) The first persistence diagram pts1: array of shape (n_top_features, 2) The second persistence diagram y_axis: optional, default="death" What the y-axis of the diagram represents. Should be one of * ``"lifetime"`` * ``"death"`` p: int, optional (default=2) The p in the p-Wasserstein distance to compute q: 1, 2 or np.inf, optional (default = 2) The q for the internal distance between the points, L_q. Uses L_infty (Chebyshev) distance if q = np.inf. Currently not implemented for other q. Returns ------- distance: float The p-Wasserstein distance between diagrams ``pts0`` and ``pts1`` """ # Convert the diagram back to birth death coordinates if passed as birth, lifetime if y_axis == "lifetime": pts0[:, 1] = pts0[:, 0] + pts0[:, 1] pts1[:, 1] = pts1[:, 0] + pts1[:, 1] elif y_axis == 'death': pass else: raise ValueError("y_axis must be 'death' or 'lifetime'") # Check q. Eventually want to remove the q <=2 part. if type(q) == int and q >= 3: raise ValueError( "q (for the internal L_q) is currently only available for 1, 2, or np.inf") elif q == 1: # Distance to diagonal in L1 distance is just the lifetime extra_dist0 = (pts0[:, 1] - pts0[:, 0]) extra_dist1 = (pts1[:, 1] - pts1[:, 0]) elif (q >= 2): # Distance to diagonal in Lq distance # Closest point to (a,b) is at (x,x) for x = a + (b-a)/2 extra_dist0 = (pts0[:, 1] - pts0[:, 0]) * 2**(1/q - 1) extra_dist1 = (pts1[:, 1] - pts1[:, 0]) * 2**(1/q - 1) elif q == np.inf: extra_dist0 = (pts0[:, 1] - pts0[:, 0]) / 2 extra_dist1 = (pts1[:, 1] - pts1[:, 0]) / 2 else: raise ValueError("q must 1, 2, or np.inf") # Get distances between all pairs of off-diagonal points # When we fix this for more q options, if q == np.inf: metric = 'chebyshev' elif q == 1: metric = 'l1' elif q == 2: metric = 'l2' pairwise_dist = pairwise_distances(pts0, pts1, metric=metric) # Add a row and column corresponding to the distance to the diagonal all_pairs_ground_distance_a = np.hstack( [pairwise_dist, extra_dist0[:, np.newaxis]]) extra_row = np.zeros(all_pairs_ground_distance_a.shape[1]) extra_row[: pairwise_dist.shape[1]] = extra_dist1 all_pairs_ground_distance_a = np.vstack( [all_pairs_ground_distance_a, extra_row]) # Raise all distances to the pth power all_pairs_ground_distance_a = all_pairs_ground_distance_a ** p # Build vector representing the mass at each location # For n0 points in the first diagram and n1 in the second, # the total mass for each diagram is n0+n1. # The mass for all off diagonal points are 1, and # remaining weight is placed on the diagonal. n0 = pts0.shape[0] n1 = pts1.shape[0] a = np.ones(n0 + 1) a[n0] = n1 a = a / a.sum() b = np.ones(n1 + 1) b[n1] = n0 b = b / b.sum() # Get the distance according to optimal transport otDist = ot.emd2(a, b, all_pairs_ground_distance_a) # Multiply by the total mass and raise to the pth power out = np.power((n0 + n1) * otDist, 1.0 / p) return out
[docs]def bottleneckDist( pts0: np.ndarray, pts1: np.ndarray, matching=True, plot=True ): """ Compute the bottleneck distance between the diagrams pts0, pts1 using the persim package: https://persim.scikit-tda.org/en/latest/index.html Parameters ---------- pts0: array of shape (n_top_features, 2) The first persistence diagram pts1: array of shape (n_top_features, 2) The second persistence diagram matching: boolean, True returns matched array plot: boolean, True provides plot of matching points. Matching must be true for plot to return. Returns ------- distance: float The bottleneck distance between diagrams ``pts0`` and ``pts1`` """ if matching == True and plot == True: d, matching = persim.bottleneck(pts0, pts1, matching=matching) persim.bottleneck_matching(pts0, pts1, matching) return d, matching if matching == True and plot == False: d, matching = persim.bottleneck(pts0, pts1, matching=matching) return d, matching if matching == False and plot == True: raise Exception("Matching must be 'True' to enable plotting'") else: d = persim.bottleneck(pts0, pts1, matching=False) return d