ozan_min_norm_solvers.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. import math
  2. import numpy as np
  3. import torch
  4. class MinNormSolver:
  5. MAX_ITER = 250
  6. STOP_CRIT = 1e-5
  7. def _min_norm_element_from2(v1v1, v1v2, v2v2):
  8. """
  9. Analytical solution for min_{c} |cx_1 + (1-c)x_2|_2^2
  10. d is the distance (objective) optimzed
  11. v1v1 = <x1,x1>
  12. v1v2 = <x1,x2>
  13. v2v2 = <x2,x2>
  14. """
  15. if v1v2 >= v1v1:
  16. # Case: Fig 1, third column
  17. gamma = 0.999
  18. cost = v1v1
  19. return gamma, cost
  20. if v1v2 >= v2v2:
  21. # Case: Fig 1, first column
  22. gamma = 0.001
  23. cost = v2v2
  24. return gamma, cost
  25. # Case: Fig 1, second column
  26. gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2))
  27. cost = v2v2 + gamma * (v1v2 - v2v2)
  28. return gamma, cost
  29. def _min_norm_2d(vecs, dps):
  30. """
  31. Find the minimum norm solution as combination of two points
  32. This is correct only in 2D
  33. ie. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1 , 1 >= c_1 >= 0 for all i, c_i + c_j = 1.0 for some i, j
  34. """
  35. dmin = 1e99
  36. sol = None
  37. for i in range(len(vecs)):
  38. for j in range(i + 1, len(vecs)):
  39. if (i, j) not in dps:
  40. dps[(i, j)] = 0.0
  41. for k in range(len(vecs[i])):
  42. dps[(i, j)] += torch.dot(vecs[i][k], vecs[j][k]).item() # .data[0]
  43. dps[(j, i)] = dps[(i, j)]
  44. if (i, i) not in dps:
  45. dps[(i, i)] = 0.0
  46. for k in range(len(vecs[i])):
  47. dps[(i, i)] += torch.dot(vecs[i][k], vecs[i][k]).item() # .data[0]
  48. if (j, j) not in dps:
  49. dps[(j, j)] = 0.0
  50. for k in range(len(vecs[i])):
  51. dps[(j, j)] += torch.dot(vecs[j][k], vecs[j][k]).item() # .data[0]
  52. c, d = MinNormSolver._min_norm_element_from2(dps[(i, i)], dps[(i, j)], dps[(j, j)])
  53. # print('c,d',c,d)
  54. if d < dmin:
  55. dmin = d
  56. sol = [(i, j), c, d]
  57. if sol is None or math.isnan(c):
  58. raise ValueError('A numeric instability occured in ozan_min_norm_solvers.')
  59. return sol, dps
  60. def _projection2simplex(y):
  61. """
  62. Given y, it solves argmin_z |y-z|_2 st \sum z = 1 , 1 >= z_i >= 0 for all i
  63. """
  64. m = len(y)
  65. sorted_y = np.flip(np.sort(y), axis=0)
  66. tmpsum = 0.0
  67. tmax_f = (np.sum(y) - 1.0) / m
  68. for i in range(m - 1):
  69. tmpsum += sorted_y[i]
  70. tmax = (tmpsum - 1) / (i + 1.0)
  71. if tmax > sorted_y[i + 1]:
  72. tmax_f = tmax
  73. break
  74. return np.maximum(y - tmax_f, np.zeros(y.shape))
  75. def _next_point(cur_val, grad, n):
  76. proj_grad = grad - (np.sum(grad) / n)
  77. tm1 = -1.0 * cur_val[proj_grad < 0] / proj_grad[proj_grad < 0]
  78. tm2 = (1.0 - cur_val[proj_grad > 0]) / (proj_grad[proj_grad > 0])
  79. skippers = np.sum(tm1 < 1e-7) + np.sum(tm2 < 1e-7)
  80. t = 1
  81. if len(tm1[tm1 > 1e-7]) > 0:
  82. t = np.min(tm1[tm1 > 1e-7])
  83. if len(tm2[tm2 > 1e-7]) > 0:
  84. t = min(t, np.min(tm2[tm2 > 1e-7]))
  85. next_point = proj_grad * t + cur_val
  86. next_point = MinNormSolver._projection2simplex(next_point)
  87. return next_point
  88. def find_min_norm_element(vecs):
  89. """
  90. Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
  91. as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
  92. It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
  93. Hence, we find the best 2-task solution, and then run the projected gradient descent until convergence
  94. """
  95. # Solution lying at the combination of two points
  96. dps = {}
  97. init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
  98. n = len(vecs)
  99. sol_vec = np.zeros(n)
  100. sol_vec[init_sol[0][0]] = init_sol[1]
  101. sol_vec[init_sol[0][1]] = 1 - init_sol[1]
  102. if n < 3:
  103. # This is optimal for n=2, so return the solution
  104. return sol_vec, init_sol[2]
  105. iter_count = 0
  106. grad_mat = np.zeros((n, n))
  107. for i in range(n):
  108. for j in range(n):
  109. grad_mat[i, j] = dps[(i, j)]
  110. while iter_count < MinNormSolver.MAX_ITER:
  111. grad_dir = -1.0 * np.dot(grad_mat, sol_vec)
  112. new_point = MinNormSolver._next_point(sol_vec, grad_dir, n)
  113. # Re-compute the inner products for line search
  114. v1v1 = 0.0
  115. v1v2 = 0.0
  116. v2v2 = 0.0
  117. for i in range(n):
  118. for j in range(n):
  119. v1v1 += sol_vec[i] * sol_vec[j] * dps[(i, j)]
  120. v1v2 += sol_vec[i] * new_point[j] * dps[(i, j)]
  121. v2v2 += new_point[i] * new_point[j] * dps[(i, j)]
  122. nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
  123. new_sol_vec = nc * sol_vec + (1 - nc) * new_point
  124. change = new_sol_vec - sol_vec
  125. if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
  126. return sol_vec, nd
  127. sol_vec = new_sol_vec
  128. def find_min_norm_element_FW(vecs):
  129. """
  130. Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
  131. as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
  132. It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
  133. Hence, we find the best 2-task solution, and then run the Frank Wolfe until convergence
  134. """
  135. # Solution lying at the combination of two points
  136. dps = {}
  137. init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
  138. n = len(vecs)
  139. sol_vec = np.zeros(n)
  140. sol_vec[init_sol[0][0]] = init_sol[1]
  141. sol_vec[init_sol[0][1]] = 1 - init_sol[1]
  142. if n < 3:
  143. # This is optimal for n=2, so return the solution
  144. return sol_vec, init_sol[2]
  145. iter_count = 0
  146. grad_mat = np.zeros((n, n))
  147. for i in range(n):
  148. for j in range(n):
  149. grad_mat[i, j] = dps[(i, j)]
  150. while iter_count < MinNormSolver.MAX_ITER:
  151. t_iter = np.argmin(np.dot(grad_mat, sol_vec))
  152. v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec))
  153. v1v2 = np.dot(sol_vec, grad_mat[:, t_iter])
  154. v2v2 = grad_mat[t_iter, t_iter]
  155. nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
  156. new_sol_vec = nc * sol_vec
  157. new_sol_vec[t_iter] += 1 - nc
  158. change = new_sol_vec - sol_vec
  159. if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
  160. return sol_vec, nd
  161. sol_vec = new_sol_vec
  162. def gradient_normalizers(grads, losses, normalization_type):
  163. gn = {}
  164. if normalization_type == 'l2':
  165. for t in grads:
  166. gn[t] = np.sqrt(np.sum([gr.pow(2).sum().data[0] for gr in grads[t]]))
  167. elif normalization_type == 'loss':
  168. for t in grads:
  169. gn[t] = losses[t]
  170. elif normalization_type == 'loss+':
  171. for t in grads:
  172. gn[t] = losses[t] * np.sqrt(np.sum([gr.pow(2).sum().data[0] for gr in grads[t]]))
  173. elif normalization_type == 'none':
  174. for t in grads:
  175. gn[t] = 1.0
  176. else:
  177. print('ERROR: Invalid Normalization Type')
  178. return gn