Coverage for bmm_multitask_learning/task_clustering/MultiTask_Algo.py: 0%

479 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-13 13:33 +0000

1from MultiTask_Base_Class import MultiTaskNNBase 

2import numpy as np 

3from scipy.optimize import minimize 

4from scipy.stats import multivariate_normal 

5from scipy.special import logsumexp, softmax 

6from sklearn.cluster import KMeans 

7from sklearn.linear_model import LogisticRegression 

8from scipy.linalg import solve_triangular, cholesky 

9from sklearn.preprocessing import StandardScaler 

10from tqdm import tqdm 

11 

class MultiTaskNN(MultiTaskNNBase):
    """Basic multi-task neural network with shared hidden layer and task-specific output weights.

    All tasks share the hidden-layer weights ``W``; each task's output weights
    are given a common Gaussian prior N(m, Sigma) with observation noise
    ``sigma``.  Hyperparameters are fit by maximizing the marginal likelihood;
    per-task MAP weight estimates are stored in ``self.A_map``.
    """

    def __init__(self, n_input, n_hidden, n_tasks, activation='tanh'):
        super().__init__(n_input, n_hidden, n_tasks, activation)

        # Hyperparameters of the shared prior over output weights.
        # n_hidden + 1 accounts for the hidden-layer bias unit.
        self.m = np.random.randn(n_hidden + 1) * 0.5
        self.Sigma = np.eye(n_hidden + 1) * 0.5
        self.sigma = 1.0  # observation noise standard deviation

    def log_likelihood(self, params, all_stats):
        """Compute the marginal log likelihood with numerical stability.

        Parameters are packed flat as ``[W, m, tril(L), log_sigma]`` where
        ``Sigma = L @ L.T``.  ``all_stats`` holds per-task sufficient
        statistics dicts with keys 'sum_hhT', 'sum_hy', 'sum_yy', 'n_samples'.

        Returns ``-inf`` on any numerical failure so the optimizer backs off.
        Side effect: refreshes ``self.A_map`` with per-task MAP estimates.
        """
        try:
            # Unpack parameters
            param_idx = 0

            # W — unpacked to keep the flat layout consistent.
            # NOTE(review): the sufficient statistics were precomputed with a
            # fixed self.W, so this W cannot influence the likelihood value;
            # confirm whether W was meant to be re-optimized here.
            W_size = self.n_hidden * (self.n_input + 1)
            W = params[param_idx:param_idx + W_size].reshape(self.n_hidden, self.n_input + 1)
            param_idx += W_size

            # m — prior mean of output weights
            m_size = self.n_hidden + 1
            m = params[param_idx:param_idx + m_size]
            param_idx += m_size

            # Sigma (Cholesky factor packed as the lower triangle)
            L = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
            tril_indices = np.tril_indices(self.n_hidden + 1)
            L[tril_indices] = params[param_idx:param_idx + len(tril_indices[0])]
            param_idx += len(tril_indices[0])

            # sigma (log scale keeps it positive under unconstrained steps)
            log_sigma = params[param_idx]
            sigma = np.exp(log_sigma)

            total_log_lik = 0.0
            self.A_map = []

            # Jitter keeps Sigma strictly positive definite
            Sigma = np.dot(L, L.T) + 1e-6 * np.eye(self.n_hidden + 1)

            # Precompute Sigma inverse via its Cholesky factor
            try:
                L_sigma = cholesky(Sigma, lower=True)
                Sigma_inv = solve_triangular(L_sigma, np.eye(self.n_hidden + 1), lower=True)
                Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)
            except np.linalg.LinAlgError:
                return -np.inf

            for stats in all_stats:
                sum_hhT = stats['sum_hhT']
                sum_hy = stats['sum_hy']
                sum_yy = stats['sum_yy']
                n_samples = stats['n_samples']

                # Clamp to avoid division by zero
                sigma_sq = max(sigma ** 2, 1e-8)

                # Posterior precision of the task weights, with regularization
                Q_i = (1.0 / sigma_sq) * sum_hhT + Sigma_inv

                try:
                    L_Q = cholesky(Q_i + 1e-6 * np.eye(self.n_hidden + 1), lower=True)
                    Q_inv = solve_triangular(L_Q, np.eye(self.n_hidden + 1), lower=True)
                    Q_inv = np.dot(Q_inv.T, Q_inv)
                except np.linalg.LinAlgError:
                    return -np.inf

                R_i = (1.0 / sigma_sq) * sum_hy + np.dot(Sigma_inv, m)

                # MAP estimate of this task's output weights (regularized solve)
                A_i = np.linalg.solve(Q_i + 1e-6 * np.eye(self.n_hidden + 1), R_i)
                self.A_map.append(A_i)

                # Log determinants straight from the Cholesky diagonals
                logdet_Q_i = 2 * np.sum(np.log(np.diag(L_Q)))
                logdet_Sigma = 2 * np.sum(np.log(np.diag(L_sigma)))

                # Marginal log-likelihood terms for this task
                term1 = -0.5 * (logdet_Sigma + n_samples * 2 * log_sigma + logdet_Q_i)
                term2 = 0.5 * (
                    np.dot(R_i, np.dot(Q_inv, R_i)) - (1.0 / sigma_sq) * sum_yy - np.dot(m, np.dot(Sigma_inv, m)))

                if not np.isfinite(term1 + term2):
                    return -np.inf

                total_log_lik += term1 + term2

            return total_log_lik if np.isfinite(total_log_lik) else -np.inf

        except Exception:
            # Narrowed from a bare ``except`` so KeyboardInterrupt/SystemExit
            # propagate; any numerical failure still yields -inf.
            return -np.inf

    def fit(self, X_list, y_list, max_iter=100):
        """Fit hyperparameters by maximizing the marginal likelihood.

        Parameters
        ----------
        X_list, y_list : per-task design matrices and target vectors.
        max_iter : maximum L-BFGS-B iterations.

        Returns the scipy optimization result, or ``None`` on failure.
        """
        # Normalize data
        X_list, y_list = self._normalize_data(X_list, y_list)

        # Compute per-task sufficient statistics once, up front
        all_stats = [self.compute_sufficient_statistics(X, y) for X, y in zip(X_list, y_list)]

        # Initial flat parameter vector: [W, m, tril(L), log_sigma]
        initial_params = []
        initial_params.extend(self.W.flatten())
        initial_params.extend(self.m)

        # Initialize Sigma's Cholesky factor (jitter keeps the factorization valid)
        L = np.linalg.cholesky(self.Sigma + 1e-6 * np.eye(self.n_hidden + 1))
        tril_indices = np.tril_indices(self.n_hidden + 1)
        initial_params.extend(L[tril_indices])

        initial_params.append(np.log(self.sigma))

        # Optimize with bounds for stability
        bounds = []
        bounds.extend([(None, None)] * (self.n_hidden * (self.n_input + 1)))  # W
        bounds.extend([(None, None)] * (self.n_hidden + 1))  # m

        # L - diagonal elements must be positive
        for i in range(len(tril_indices[0])):
            if tril_indices[0][i] == tril_indices[1][i]:  # diagonal
                bounds.append((1e-8, None))
            else:
                bounds.append((None, None))

        bounds.append((np.log(1e-8), None))  # log_sigma

        # Optimization with error handling
        try:
            result = minimize(
                lambda p: -self.log_likelihood(p, all_stats),
                initial_params,
                method='L-BFGS-B',
                bounds=bounds,
                options={
                    'maxiter': max_iter,
                    'disp': True,
                    'maxfun': 15000,
                    'maxls': 50
                }
            )

            # Store optimized parameters
            self._unpack_parameters(result.x)

            # Recompute MAP estimates at the optimum (side effect of log_likelihood)
            _ = self.log_likelihood(result.x, all_stats)

            return result

        except Exception as e:
            print(f"Optimization failed: {str(e)}")
            return None

    def _unpack_parameters(self, params):
        """Unpack the flat optimizer vector into self.W, self.m, self.Sigma, self.sigma."""
        param_idx = 0

        # W
        W_size = self.n_hidden * (self.n_input + 1)
        self.W = params[param_idx:param_idx + W_size].reshape(self.n_hidden, self.n_input + 1)
        param_idx += W_size

        # m
        m_size = self.n_hidden + 1
        self.m = params[param_idx:param_idx + m_size]
        param_idx += m_size

        # Sigma, rebuilt from its packed Cholesky factor with jitter
        L = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
        tril_indices = np.tril_indices(self.n_hidden + 1)
        L[tril_indices] = params[param_idx:param_idx + len(tril_indices[0])]
        param_idx += len(tril_indices[0])
        self.Sigma = np.dot(L, L.T) + 1e-6 * np.eye(self.n_hidden + 1)

        # sigma (clamped away from zero)
        self.sigma = max(np.exp(params[param_idx]), 1e-8)

192 

class MultiTaskNNDependentMean(MultiTaskNNBase):
    """Multi-task NN with task-dependent prior means.

    Each task's output-weight prior mean is ``m_i = M @ f_i`` where ``f_i``
    are that task's features and ``M`` is a learned mapping shared across
    tasks.  The covariance ``Sigma`` and noise ``sigma`` are shared.
    """

    def __init__(self, n_input, n_hidden, n_tasks, n_features, activation='tanh'):
        super().__init__(n_input, n_hidden, n_tasks, activation)
        self.n_features = n_features

        # Initialize hyperparameters with better scaling
        self.M = np.random.randn(n_hidden + 1, n_features) * 0.1
        self.Sigma = np.eye(n_hidden + 1) * 0.5  # Start with smaller variance
        self.sigma = 1.0

    def log_likelihood(self, params, all_stats, all_task_features):
        """Compute the marginal log likelihood with numerical stability.

        Parameters are packed flat as ``[W, M, tril(L), log_sigma]`` where
        ``Sigma = L @ L.T``.  ``all_stats`` holds per-task sufficient
        statistics; ``all_task_features`` the matching feature vectors.

        Returns ``-inf`` when the covariance factor is unusable.
        Side effect: refreshes ``self.A_map`` with per-task MAP estimates.
        """
        # Unpack parameters
        param_idx = 0

        # W — kept in the layout for consistency.  NOTE(review): as in
        # MultiTaskNN, the sufficient statistics are precomputed from a fixed
        # self.W, so W does not affect the likelihood — confirm intended.
        W_size = self.n_hidden * (self.n_input + 1)
        W = params[param_idx:param_idx + W_size].reshape(self.n_hidden, self.n_input + 1)
        param_idx += W_size

        # M: (n_hidden + 1 x n_features) mean-mapping matrix
        M_size = (self.n_hidden + 1) * self.n_features
        M = params[param_idx:param_idx + M_size].reshape(self.n_hidden + 1, self.n_features)
        param_idx += M_size

        # Sigma (Cholesky factor packed as the lower triangle)
        L = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
        tril_indices = np.tril_indices(self.n_hidden + 1)
        L[tril_indices] = params[param_idx:param_idx + len(tril_indices[0])]
        param_idx += len(tril_indices[0])

        # sigma (log scale)
        log_sigma = params[param_idx]
        sigma = np.exp(log_sigma)

        total_log_lik = 0.0
        self.A_map = []

        # Precompute Sigma inverse directly from the packed Cholesky factor.
        # NOTE(review): solve_triangular may return inf/nan rather than raise
        # for a singular L — the bounds in fit() keep diag(L) > 0, but verify.
        try:
            Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
            Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)
        except np.linalg.LinAlgError:
            return -np.inf  # Invalid covariance matrix

        for stats, task_features in zip(all_stats, all_task_features):
            sum_hhT = stats['sum_hhT']
            sum_hy = stats['sum_hy']
            sum_yy = stats['sum_yy']
            n_samples = stats['n_samples']

            # Task-dependent prior mean
            m_i = np.dot(M, task_features)

            # Posterior precision Q_i, inverted via Cholesky for stability
            Q_i = (1.0 / (sigma ** 2)) * sum_hhT + Sigma_inv

            try:
                L_Q = np.linalg.cholesky(Q_i)
                Q_inv = solve_triangular(L_Q, np.eye(self.n_hidden + 1), lower=True)
                Q_inv = np.dot(Q_inv.T, Q_inv)
            except np.linalg.LinAlgError:
                return -np.inf

            R_i = (1.0 / (sigma ** 2)) * sum_hy + np.dot(Sigma_inv, m_i)

            # MAP estimate of this task's output weights
            A_i = np.dot(Q_inv, R_i)
            self.A_map.append(A_i)

            # Log determinants straight from the Cholesky diagonals
            logdet_Q_i = 2 * np.sum(np.log(np.diag(L_Q)))
            logdet_Sigma = 2 * np.sum(np.log(np.diag(L)))

            # Marginal log-likelihood terms for this task
            term1 = -0.5 * (logdet_Sigma + n_samples * 2 * log_sigma + logdet_Q_i)
            term2 = 0.5 * (np.dot(R_i, np.dot(Q_inv, R_i)) - (1.0 / (sigma ** 2)) * sum_yy - np.dot(m_i,
                                                                                                    np.dot(Sigma_inv,
                                                                                                           m_i)))

            total_log_lik += term1 + term2

        return total_log_lik

    def fit(self, X_list, y_list, task_features_list, max_iter=100):
        """Fit W, M, Sigma and sigma by maximizing the marginal likelihood.

        Parameters
        ----------
        X_list, y_list : per-task design matrices and target vectors.
        task_features_list : per-task feature vectors fed to the mean map M.
        max_iter : maximum L-BFGS-B iterations.

        Returns the scipy optimization result.
        """
        # Normalize data per task
        X_list = [(X - np.mean(X, axis=0)) / (np.std(X, axis=0) + 1e-8) for X in X_list]
        y_list = [(y - np.mean(y)) / (np.std(y) + 1e-8) for y in y_list]

        # Compute sufficient statistics once, up front
        all_stats = [self.compute_sufficient_statistics(X, y) for X, y in zip(X_list, y_list)]

        # Initial flat parameter vector: [W, M, tril(L), log_sigma]
        initial_params = []
        initial_params.extend(self.W.flatten())
        initial_params.extend(self.M.flatten())

        L = np.linalg.cholesky(self.Sigma + 1e-6 * np.eye(self.n_hidden + 1))
        tril_indices = np.tril_indices(self.n_hidden + 1)
        initial_params.extend(L[tril_indices])

        initial_params.append(np.log(self.sigma))

        # Optimize with bounds for stability
        bounds = []

        # W - no bounds
        bounds.extend([(None, None)] * (self.n_hidden * (self.n_input + 1)))

        # M - no bounds
        bounds.extend([(None, None)] * ((self.n_hidden + 1) * self.n_features))

        # L - diagonal elements must be positive
        for i in range(len(tril_indices[0])):
            if tril_indices[0][i] == tril_indices[1][i]:  # diagonal
                bounds.append((1e-8, None))
            else:
                bounds.append((None, None))

        # log_sigma must be > log(1e-8)
        bounds.append((np.log(1e-8), None))

        # Optimize
        result = minimize(
            lambda p: -self.log_likelihood(p, all_stats, task_features_list),
            initial_params,
            method='L-BFGS-B',
            bounds=bounds,
            options={'maxiter': max_iter, 'disp': True}
        )

        # Store optimized parameters
        self._unpack_parameters(result.x)

        # Recompute MAP estimates at the optimum (side effect of log_likelihood)
        _ = self.log_likelihood(result.x, all_stats, task_features_list)

        return result

    def _unpack_parameters(self, params):
        """Unpack the flat optimizer vector into self.W, self.M, self.Sigma, self.sigma."""
        param_idx = 0

        # W
        W_size = self.n_hidden * (self.n_input + 1)
        self.W = params[param_idx:param_idx + W_size].reshape(self.n_hidden, self.n_input + 1)
        param_idx += W_size

        # M: (n_hidden + 1 x n_features).  BUGFIX: this was previously bound
        # to a local variable, so the optimized M was silently discarded and
        # self.M kept its random initialization.
        M_size = (self.n_hidden + 1) * self.n_features
        self.M = params[param_idx:param_idx + M_size].reshape(self.n_hidden + 1, self.n_features)
        param_idx += M_size

        # Sigma, rebuilt from its packed Cholesky factor.  Jitter added for
        # consistency with MultiTaskNN._unpack_parameters and with the
        # factorization performed in fit().
        L = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
        tril_indices = np.tril_indices(self.n_hidden + 1)
        L[tril_indices] = params[param_idx:param_idx + len(tril_indices[0])]
        param_idx += len(tril_indices[0])
        self.Sigma = np.dot(L, L.T) + 1e-6 * np.eye(self.n_hidden + 1)

        # sigma
        self.sigma = np.exp(params[param_idx])

359 

class MultiTaskNNClustering(MultiTaskNNBase):
    """Multi-task NN with task clustering.

    Tasks are soft-assigned via EM to one of ``n_clusters`` Gaussian priors
    N(m_alpha, Sigma_alpha) over their output weights.  ``self.z`` holds the
    task-cluster responsibilities and ``self.q`` the cluster mixing weights.
    """

    def __init__(self, n_input, n_hidden, n_tasks, n_clusters, activation='tanh'):
        super().__init__(n_input, n_hidden, n_tasks, activation)
        self.n_clusters = n_clusters

        # Initialize with larger scale and better conditioning
        self.q = np.ones(n_clusters) / n_clusters  # uniform mixing weights
        self.m = np.random.randn(n_clusters, n_hidden + 1) * 0.5

        # Initialize Sigma with larger diagonal for numerical stability
        self.Sigma = np.array([np.eye(n_hidden + 1) * 0.5 for _ in range(n_clusters)])
        self.sigma = 1.0

        self.z = np.zeros((n_tasks, n_clusters))  # responsibilities

    def _compute_task_log_likelihood(self, X_i, y_i, cluster_idx):
        """Marginal log likelihood of task data (X_i, y_i) under cluster ``cluster_idx``.

        Returns -inf when a covariance/precision matrix is not positive definite.
        """
        n_i = len(y_i)
        h_i = self.compute_hidden_activations(X_i)

        # Clamp to avoid division by zero
        sigma_sq = max(self.sigma ** 2, 1e-8)

        try:
            # Use Cholesky decomposition for numerical stability
            L = cholesky(self.Sigma[cluster_idx], lower=True)
            Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
            Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)

            Q_i = (1 / sigma_sq) * np.dot(h_i.T, h_i) + Sigma_inv
            L_Q = cholesky(Q_i, lower=True)
            Q_inv = solve_triangular(L_Q, np.eye(self.n_hidden + 1), lower=True)
            Q_inv = np.dot(Q_inv.T, Q_inv)

            R_i = (1 / sigma_sq) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[cluster_idx])

            # Log determinants straight from the Cholesky diagonals
            logdet_Sigma = 2 * np.sum(np.log(np.diag(L)))
            logdet_Q_i = 2 * np.sum(np.log(np.diag(L_Q)))

            term1 = -0.5 * (logdet_Sigma + n_i * np.log(sigma_sq) + logdet_Q_i)
            # NOTE(review): the (1 / (2 * sigma_sq)) factor sits inside an
            # outer 0.5, giving a net 1/4 on sum(y^2)/sigma^2, whereas
            # MultiTaskNN.log_likelihood uses a net 1/2 — confirm which is
            # intended before changing (left as-is to preserve behavior).
            term2 = 0.5 * (np.dot(R_i.T, np.dot(Q_inv, R_i)) - (1 / (2 * sigma_sq)) * np.sum(y_i ** 2) - np.dot(
                self.m[cluster_idx].T, np.dot(Sigma_inv, self.m[cluster_idx])))

            return term1 + term2

        except np.linalg.LinAlgError:
            # Return -inf if matrix is not positive definite
            return -np.inf

    def e_step(self, data):
        """E-step: update soft task-cluster responsibilities ``self.z``."""
        log_responsibilities = np.zeros((self.n_tasks, self.n_clusters))

        for i, (X_i, y_i) in enumerate(data):
            for alpha in range(self.n_clusters):
                log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                log_responsibilities[i, alpha] = np.log(self.q[alpha] + 1e-8) + log_lik

            # Normalize using logsumexp for numerical stability
            log_responsibilities[i] -= logsumexp(log_responsibilities[i])

        self.z = np.exp(log_responsibilities)

    def m_step(self, data):
        """M-step: re-optimize W and sigma, then update q, m and Sigma per cluster."""

        def objective(params):
            # Responsibility-weighted negative log likelihood in (W, log_sigma).
            # Mutates self.W / self.sigma so _compute_task_log_likelihood sees
            # the candidate values.
            W = params[:self.n_hidden * (self.n_input + 1)].reshape(self.n_hidden, self.n_input + 1)
            log_sigma = params[-1]
            sigma = np.exp(log_sigma)

            self.W = W
            self.sigma = max(sigma, 1e-8)  # Prevent sigma from becoming too small

            total_log_lik = 0.0
            for i, (X_i, y_i) in enumerate(data):
                for alpha in range(self.n_clusters):
                    log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                    total_log_lik += self.z[i, alpha] * log_lik

            return -total_log_lik if np.isfinite(total_log_lik) else np.inf

        # Initial parameters with bounds
        initial_params = np.concatenate([
            self.W.flatten(),
            [np.log(self.sigma)]
        ])

        # Add bounds for sigma (log_sigma > log(1e-8))
        bounds = [(None, None)] * len(initial_params)
        bounds[-1] = (np.log(1e-8), None)

        result = minimize(
            objective,
            initial_params,
            method='L-BFGS-B',
            bounds=bounds,
            options={'maxiter': 50, 'disp': True}
        )

        opt_params = result.x
        W_size = self.n_hidden * (self.n_input + 1)
        self.W = opt_params[:W_size].reshape(self.n_hidden, self.n_input + 1)
        self.sigma = max(np.exp(opt_params[-1]), 1e-8)

        # Update cluster parameters with regularization

        for alpha in range(self.n_clusters):
            self.q[alpha] = max(np.sum(self.z[:, alpha]) / self.n_tasks, 1e-8)

            sum_z = np.sum(self.z[:, alpha])
            if sum_z > 1e-8:
                # Update m_alpha from responsibility-weighted posterior moments
                weighted_R = np.zeros(self.n_hidden + 1)
                weighted_Q = np.zeros((self.n_hidden + 1, self.n_hidden + 1))

                for i, (X_i, y_i) in enumerate(data):
                    h_i = self.compute_hidden_activations(X_i)
                    L = cholesky(self.Sigma[alpha] + 1e-6 * np.eye(self.n_hidden + 1), lower=True)
                    Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
                    Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)

                    Q_i = (1 / max(self.sigma ** 2, 1e-8)) * np.dot(h_i.T, h_i) + Sigma_inv
                    R_i = (1 / max(self.sigma ** 2, 1e-8)) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[alpha])

                    weighted_R += self.z[i, alpha] * R_i
                    weighted_Q += self.z[i, alpha] * Q_i

                try:
                    self.m[alpha] = np.linalg.solve(weighted_Q + 1e-6 * np.eye(self.n_hidden + 1), weighted_R)
                except np.linalg.LinAlgError:
                    # Best-effort: keep the previous m_alpha if the solve fails
                    # (narrowed from a bare ``except``)
                    pass

                # Update covariance from weighted scatter of MAP deviations
                weighted_cov = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
                for i, (X_i, y_i) in enumerate(data):
                    A_i = self._compute_map_estimate(X_i, y_i, alpha)
                    diff = A_i - self.m[alpha]
                    weighted_cov += self.z[i, alpha] * np.outer(diff, diff)

                self.Sigma[alpha] = weighted_cov / sum_z + 1e-6 * np.eye(self.n_hidden + 1)

    def _compute_map_estimate(self, X_i, y_i, cluster_idx):
        """MAP estimate of task output weights under the given cluster's prior."""
        h_i = self.compute_hidden_activations(X_i)
        sigma_sq = max(self.sigma ** 2, 1e-8)

        L = cholesky(self.Sigma[cluster_idx] + 1e-6 * np.eye(self.n_hidden + 1), lower=True)
        Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
        Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)

        Q_i = (1 / sigma_sq) * np.dot(h_i.T, h_i) + Sigma_inv
        R_i = (1 / sigma_sq) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[cluster_idx])

        return np.linalg.solve(Q_i + 1e-6 * np.eye(self.n_hidden + 1), R_i)

    def fit(self, data, max_iter=100, tol=1e-4):
        """Run EM until the total log likelihood converges or max_iter is hit."""
        prev_log_lik = -np.inf

        for iteration in tqdm((range(max_iter))):
            self.e_step(data)
            self.m_step(data)

            # Compute current log likelihood
            current_log_lik = 0.0
            for i, (X_i, y_i) in enumerate(data):
                cluster_log_liks = []
                for alpha in range(self.n_clusters):
                    log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                    cluster_log_liks.append(np.log(self.q[alpha] + 1e-8) + log_lik)
                current_log_lik += logsumexp(cluster_log_liks)

            if np.isnan(current_log_lik):
                print("Warning: log likelihood is nan, stopping early")
                break

            if iteration > 0 and np.abs(current_log_lik - prev_log_lik) < tol:
                print(f"Converged at iteration {iteration}")
                break

            prev_log_lik = current_log_lik

        self._compute_final_weights(data)

    # NOTE: the following three methods were previously defined twice with
    # identical bodies; the redundant second definitions have been removed.

    def _compute_final_weights(self, data):
        """Assign each task its MAP weights under its most likely cluster."""
        for i, (X_i, y_i) in enumerate(data):
            most_likely_cluster = np.argmax(self.z[i])
            self.A_map[i] = self._compute_map_estimate(X_i, y_i, most_likely_cluster)

    def get_cluster_assignments(self):
        """Return the hard cluster index per task (argmax of responsibilities)."""
        return np.argmax(self.z, axis=1)

    def get_task_similarity(self):
        """Return an (n_tasks x n_tasks) 0/1 matrix: 1 iff two tasks share a cluster."""
        assignments = self.get_cluster_assignments()
        return np.array([[1.0 if a == b else 0.0 for b in assignments] for a in assignments])

565 

class MultiTaskNNGating(MultiTaskNNBase):
    """Multi-task NN with a gating network for task clustering.

    Like MultiTaskNNClustering, but the cluster mixing probabilities are
    produced per task by a softmax gating network ``softmax(F @ U.T)`` over
    task features F, with U refit each M-step via logistic regression.
    """

    def __init__(self, n_input, n_hidden, n_tasks, n_clusters, n_features, activation='tanh'):
        super().__init__(n_input, n_hidden, n_tasks, activation)
        self.n_clusters = n_clusters
        self.n_features = n_features

        # Initialize with larger scale for better convergence
        self.U = np.random.randn(n_clusters, n_features) * 0.5  # gating weights
        self.m = np.random.randn(n_clusters, n_hidden + 1) * 0.5

        # Initialize covariance matrices with larger diagonal
        self.Sigma = np.array([np.eye(n_hidden + 1) * 0.5 for _ in range(n_clusters)])
        self.sigma = 1.0

        self.z = np.zeros((n_tasks, n_clusters))  # responsibilities

    def compute_gating_probabilities(self, F):
        """Compute task-cluster assignment probabilities softmax(F @ U.T)."""
        # Ensure F is 2D; broadcast a single feature row to all tasks
        F = np.atleast_2d(F)
        if F.shape[0] == 1 and self.n_tasks > 1:
            F = np.repeat(F, self.n_tasks, axis=0)

        logits = np.dot(F, self.U.T)
        return softmax(logits, axis=1)

    def _compute_task_log_likelihood(self, X_i, y_i, cluster_idx):
        """Marginal log likelihood of task data (X_i, y_i) under cluster ``cluster_idx``.

        Returns -inf when a covariance/precision matrix is not positive definite.
        """
        n_i = len(y_i)
        h_i = self.compute_hidden_activations(X_i)

        # Clamp to avoid division by zero
        sigma_sq = max(self.sigma ** 2, 1e-8)

        try:
            # Use Cholesky decomposition for numerical stability
            L = cholesky(self.Sigma[cluster_idx] + 1e-6 * np.eye(self.n_hidden + 1), lower=True)
            Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
            Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)

            Q_i = (1 / sigma_sq) * np.dot(h_i.T, h_i) + Sigma_inv
            L_Q = cholesky(Q_i, lower=True)
            Q_inv = solve_triangular(L_Q, np.eye(self.n_hidden + 1), lower=True)
            Q_inv = np.dot(Q_inv.T, Q_inv)

            R_i = (1 / sigma_sq) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[cluster_idx])

            # Log determinants straight from the Cholesky diagonals
            logdet_Sigma = 2 * np.sum(np.log(np.diag(L)))
            logdet_Q_i = 2 * np.sum(np.log(np.diag(L_Q)))

            term1 = -0.5 * (logdet_Sigma + n_i * np.log(sigma_sq) + logdet_Q_i)
            # NOTE(review): net factor on sum(y^2)/sigma^2 is 1/4 here vs 1/2
            # in MultiTaskNN.log_likelihood — confirm which is intended
            # (left as-is to preserve behavior).
            term2 = 0.5 * (np.dot(R_i.T, np.dot(Q_inv, R_i)) - (1 / (2 * sigma_sq)) * np.sum(y_i ** 2) - np.dot(
                self.m[cluster_idx].T, np.dot(Sigma_inv, self.m[cluster_idx])))

            return term1 + term2

        except np.linalg.LinAlgError:
            return -np.inf

    def e_step(self, data, task_features):
        """E-step: update responsibilities ``self.z`` using gated priors."""
        # Ensure task_features is 2D; broadcast a single row to all tasks
        task_features = np.atleast_2d(task_features)
        if task_features.shape[0] == 1 and self.n_tasks > 1:
            task_features = np.repeat(task_features, self.n_tasks, axis=0)

        q = self.compute_gating_probabilities(task_features)
        log_responsibilities = np.zeros((self.n_tasks, self.n_clusters))

        for i, (X_i, y_i) in enumerate(data):
            for alpha in range(self.n_clusters):
                log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                log_responsibilities[i, alpha] = np.log(q[i, alpha] + 1e-8) + log_lik

            # Normalize using logsumexp
            log_responsibilities[i] -= logsumexp(log_responsibilities[i])

        self.z = np.exp(log_responsibilities)

    def m_step(self, data, task_features):
        """M-step: re-optimize W/sigma, update cluster priors, refit gating U."""

        # Optimize W and sigma
        def objective(params):
            # Responsibility-weighted negative log likelihood in (W, log_sigma).
            # Mutates self.W / self.sigma so _compute_task_log_likelihood sees
            # the candidate values.
            W = params[:self.n_hidden * (self.n_input + 1)].reshape(self.n_hidden, self.n_input + 1)
            log_sigma = params[-1]
            sigma = np.exp(log_sigma)

            self.W = W
            self.sigma = max(sigma, 1e-8)

            total_log_lik = 0.0
            for i, (X_i, y_i) in enumerate(data):
                for alpha in range(self.n_clusters):
                    log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                    total_log_lik += self.z[i, alpha] * log_lik

            return -total_log_lik if np.isfinite(total_log_lik) else np.inf

        # Initial parameters with bounds
        initial_params = np.concatenate([
            self.W.flatten(),
            [np.log(self.sigma)]
        ])

        bounds = [(None, None)] * len(initial_params)
        bounds[-1] = (np.log(1e-8), None)  # sigma > 1e-8

        result = minimize(
            objective,
            initial_params,
            method='L-BFGS-B',
            bounds=bounds,
            options={'maxiter': 50, 'disp': True}
        )

        # Update parameters
        opt_params = result.x
        W_size = self.n_hidden * (self.n_input + 1)
        self.W = opt_params[:W_size].reshape(self.n_hidden, self.n_input + 1)
        self.sigma = max(np.exp(opt_params[-1]), 1e-8)

        # Update cluster parameters with regularization
        for alpha in range(self.n_clusters):
            sum_z = np.sum(self.z[:, alpha])
            if sum_z > 1e-8:
                # Update m_alpha from responsibility-weighted posterior moments
                weighted_R = np.zeros(self.n_hidden + 1)
                weighted_Q = np.zeros((self.n_hidden + 1, self.n_hidden + 1))

                for i, (X_i, y_i) in enumerate(data):
                    h_i = self.compute_hidden_activations(X_i)
                    L = cholesky(self.Sigma[alpha] + 1e-6 * np.eye(self.n_hidden + 1), lower=True)
                    Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
                    Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)

                    Q_i = (1 / max(self.sigma ** 2, 1e-8)) * np.dot(h_i.T, h_i) + Sigma_inv
                    R_i = (1 / max(self.sigma ** 2, 1e-8)) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[alpha])

                    weighted_R += self.z[i, alpha] * R_i
                    weighted_Q += self.z[i, alpha] * Q_i

                try:
                    self.m[alpha] = np.linalg.solve(weighted_Q + 1e-6 * np.eye(self.n_hidden + 1), weighted_R)
                except np.linalg.LinAlgError:
                    # Best-effort: keep the previous m_alpha if the solve fails
                    # (narrowed from a bare ``except``)
                    pass

                # Update Sigma_alpha from weighted scatter of MAP deviations
                weighted_cov = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
                for i, (X_i, y_i) in enumerate(data):
                    A_i = self._compute_map_estimate(X_i, y_i, alpha)
                    diff = A_i - self.m[alpha]
                    weighted_cov += self.z[i, alpha] * np.outer(diff, diff)

                self.Sigma[alpha] = weighted_cov / sum_z + 1e-6 * np.eye(self.n_hidden + 1)

        # Update gating parameters U by fitting a multinomial logistic
        # regression from task features to the current hard assignments
        if self.n_clusters > 1:
            task_features = np.atleast_2d(task_features)
            if task_features.shape[0] == 1 and self.n_tasks > 1:
                task_features = np.repeat(task_features, self.n_tasks, axis=0)

            lr = LogisticRegression(
                multi_class='multinomial',
                solver='lbfgs',
                fit_intercept=False,
                max_iter=100,
                penalty='l2',
                C=1.0
            )
            try:
                lr.fit(task_features, self.get_cluster_assignments(), sample_weight=np.max(self.z, axis=1))
                self.U = lr.coef_
            except Exception:
                # Best-effort: sklearn raises ValueError when only one class is
                # present; keep the previous U (narrowed from a bare ``except``)
                pass

    def _compute_map_estimate(self, X_i, y_i, cluster_idx):
        """MAP estimate of task output weights under the given cluster's prior."""
        h_i = self.compute_hidden_activations(X_i)
        sigma_sq = max(self.sigma ** 2, 1e-8)

        L = cholesky(self.Sigma[cluster_idx] + 1e-6 * np.eye(self.n_hidden + 1), lower=True)
        Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
        Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)

        Q_i = (1 / sigma_sq) * np.dot(h_i.T, h_i) + Sigma_inv
        R_i = (1 / sigma_sq) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[cluster_idx])

        return np.linalg.solve(Q_i + 1e-6 * np.eye(self.n_hidden + 1), R_i)

    def fit(self, data, task_features, max_iter=100, tol=1e-4):
        """Run EM with a gating network until convergence or max_iter.

        Task features are standardized here; the mean/std are stored on the
        instance for reuse at prediction time.
        """
        prev_log_lik = -np.inf

        # Normalize task features
        task_features = np.atleast_2d(task_features)
        if task_features.shape[0] == 1 and self.n_tasks > 1:
            task_features = np.repeat(task_features, self.n_tasks, axis=0)

        self.task_feature_mean = np.mean(task_features, axis=0)
        self.task_feature_std = np.std(task_features, axis=0) + 1e-8
        task_features = (task_features - self.task_feature_mean) / self.task_feature_std

        for iteration in range(max_iter):
            try:
                self.e_step(data, task_features)
                self.m_step(data, task_features)

                # Compute current log likelihood
                current_log_lik = 0.0
                q = self.compute_gating_probabilities(task_features)

                for i, (X_i, y_i) in enumerate(data):
                    cluster_log_liks = []
                    for alpha in range(self.n_clusters):
                        log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                        cluster_log_liks.append(np.log(q[i, alpha] + 1e-8) + log_lik)
                    current_log_lik += logsumexp(cluster_log_liks)

                if np.isnan(current_log_lik):
                    print("Warning: log likelihood is nan, stopping early")
                    break

                if iteration > 0 and abs(current_log_lik - prev_log_lik) < tol:
                    print(f"Converged at iteration {iteration}")
                    break

                prev_log_lik = current_log_lik
                print(f"Iteration {iteration}, log likelihood: {current_log_lik}")

            except Exception as e:
                print(f"Error at iteration {iteration}: {str(e)}")
                break

        self._compute_final_weights(data)
        return self

    def _compute_final_weights(self, data):
        """Assign each task its MAP weights under its most likely cluster."""
        for i, (X_i, y_i) in enumerate(data):
            most_likely_cluster = np.argmax(self.z[i])
            self.A_map[i] = self._compute_map_estimate(X_i, y_i, most_likely_cluster)

    def get_cluster_assignments(self):
        """Return the hard cluster index per task (argmax of responsibilities)."""
        return np.argmax(self.z, axis=1)

    def get_task_similarity(self):
        """Return an (n_tasks x n_tasks) 0/1 matrix: 1 iff two tasks share a cluster."""
        assignments = self.get_cluster_assignments()
        return np.array([[1.0 if a == b else 0.0 for b in assignments] for a in assignments])