Coverage for bmm_multitask_learning/task_clustering/MultiTask_Algo.py: 0%
479 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-13 13:33 +0000
1from MultiTask_Base_Class import MultiTaskNNBase
2import numpy as np
3from scipy.optimize import minimize
4from scipy.stats import multivariate_normal
5from scipy.special import logsumexp, softmax
6from sklearn.cluster import KMeans
7from sklearn.linear_model import LogisticRegression
8from scipy.linalg import solve_triangular, cholesky
9from sklearn.preprocessing import StandardScaler
10from tqdm import tqdm
class MultiTaskNN(MultiTaskNNBase):
    """Basic multi-task neural network with shared hidden layer and task-specific
    output weights.

    Task-specific output weights A_i share a Gaussian prior N(m, Sigma);
    sigma is the observation-noise standard deviation. The hyperparameters
    (W, m, Sigma, sigma) are fit by maximizing the marginal log likelihood
    with L-BFGS-B, with Sigma parameterized by its Cholesky factor and sigma
    on the log scale so both stay valid during optimization.
    """

    def __init__(self, n_input, n_hidden, n_tasks, activation='tanh'):
        super().__init__(n_input, n_hidden, n_tasks, activation)

        # Prior over task weights: A_i ~ N(m, Sigma); sigma is the noise std.
        self.m = np.random.randn(n_hidden + 1) * 0.5
        self.Sigma = np.eye(n_hidden + 1) * 0.5
        self.sigma = 1.0

    def log_likelihood(self, params, all_stats):
        """Compute the marginal log likelihood over all tasks.

        ``params`` packs, in order: W (flattened), m, the packed lower
        triangle of the Cholesky factor of Sigma, and log(sigma).
        ``all_stats`` is a list of per-task sufficient-statistic dicts with
        keys 'sum_hhT', 'sum_hy', 'sum_yy', 'n_samples'.

        Returns -inf on any numerical failure so the optimizer treats the
        point as invalid. Side effect: rebuilds ``self.A_map`` with the
        per-task MAP weight estimates.
        """
        try:
            param_idx = 0

            # W is unpacked for layout completeness; the sufficient statistics
            # were precomputed with the current hidden layer, so W does not
            # influence this evaluation.
            W_size = self.n_hidden * (self.n_input + 1)
            W = params[param_idx:param_idx + W_size].reshape(self.n_hidden, self.n_input + 1)
            param_idx += W_size

            # m: prior mean of the task weights
            m_size = self.n_hidden + 1
            m = params[param_idx:param_idx + m_size]
            param_idx += m_size

            # Sigma via its Cholesky factor L (lower triangle packed)
            L = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
            tril_indices = np.tril_indices(self.n_hidden + 1)
            L[tril_indices] = params[param_idx:param_idx + len(tril_indices[0])]
            param_idx += len(tril_indices[0])

            # sigma optimized on the log scale to keep it positive
            log_sigma = params[param_idx]
            sigma = np.exp(log_sigma)

            total_log_lik = 0.0
            self.A_map = []

            # Jitter keeps Sigma positive definite
            Sigma = np.dot(L, L.T) + 1e-6 * np.eye(self.n_hidden + 1)

            # Sigma^{-1} via Cholesky for numerical stability
            try:
                L_sigma = cholesky(Sigma, lower=True)
                Sigma_inv = solve_triangular(L_sigma, np.eye(self.n_hidden + 1), lower=True)
                Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)
            except np.linalg.LinAlgError:
                return -np.inf

            for stats in all_stats:
                sum_hhT = stats['sum_hhT']
                sum_hy = stats['sum_hy']
                sum_yy = stats['sum_yy']
                n_samples = stats['n_samples']

                # Guard against division by zero
                sigma_sq = max(sigma ** 2, 1e-8)

                # Posterior precision of the task weights
                Q_i = (1.0 / sigma_sq) * sum_hhT + Sigma_inv

                try:
                    L_Q = cholesky(Q_i + 1e-6 * np.eye(self.n_hidden + 1), lower=True)
                    Q_inv = solve_triangular(L_Q, np.eye(self.n_hidden + 1), lower=True)
                    Q_inv = np.dot(Q_inv.T, Q_inv)
                except np.linalg.LinAlgError:
                    return -np.inf

                R_i = (1.0 / sigma_sq) * sum_hy + np.dot(Sigma_inv, m)

                # MAP estimate of the task weights (regularized solve)
                A_i = np.linalg.solve(Q_i + 1e-6 * np.eye(self.n_hidden + 1), R_i)
                self.A_map.append(A_i)

                # log-determinants read off the Cholesky diagonals
                logdet_Q_i = 2 * np.sum(np.log(np.diag(L_Q)))
                logdet_Sigma = 2 * np.sum(np.log(np.diag(L_sigma)))

                # Marginal likelihood terms for this task
                term1 = -0.5 * (logdet_Sigma + n_samples * 2 * log_sigma + logdet_Q_i)
                term2 = 0.5 * (
                    np.dot(R_i, np.dot(Q_inv, R_i)) - (1.0 / sigma_sq) * sum_yy - np.dot(m, np.dot(Sigma_inv, m)))

                if not np.isfinite(term1 + term2):
                    return -np.inf

                total_log_lik += term1 + term2

            return total_log_lik if np.isfinite(total_log_lik) else -np.inf

        except Exception:
            # BUG FIX: was a bare `except:`. `Exception` still maps any
            # numerical/shape failure to an invalid parameter point, but no
            # longer swallows KeyboardInterrupt/SystemExit.
            return -np.inf

    def fit(self, X_list, y_list, max_iter=100):
        """Fit hyperparameters by maximizing the marginal log likelihood.

        X_list/y_list hold per-task training data. Returns the scipy
        optimization result, or None if optimization raises.
        """
        # Normalize data (helper from the base class)
        X_list, y_list = self._normalize_data(X_list, y_list)

        # Per-task sufficient statistics
        all_stats = [self.compute_sufficient_statistics(X, y) for X, y in zip(X_list, y_list)]

        # Pack the initial parameter vector: W, m, chol(Sigma), log(sigma)
        initial_params = []
        initial_params.extend(self.W.flatten())
        initial_params.extend(self.m)

        # Jitter guarantees the initial Cholesky succeeds
        L = np.linalg.cholesky(self.Sigma + 1e-6 * np.eye(self.n_hidden + 1))
        tril_indices = np.tril_indices(self.n_hidden + 1)
        initial_params.extend(L[tril_indices])

        initial_params.append(np.log(self.sigma))

        # Bounds keep the factorization valid during optimization
        bounds = []
        bounds.extend([(None, None)] * (self.n_hidden * (self.n_input + 1)))  # W
        bounds.extend([(None, None)] * (self.n_hidden + 1))  # m

        # Cholesky factor: diagonal entries must stay positive
        for i in range(len(tril_indices[0])):
            if tril_indices[0][i] == tril_indices[1][i]:  # diagonal
                bounds.append((1e-8, None))
            else:
                bounds.append((None, None))

        bounds.append((np.log(1e-8), None))  # log_sigma

        try:
            result = minimize(
                lambda p: -self.log_likelihood(p, all_stats),
                initial_params,
                method='L-BFGS-B',
                bounds=bounds,
                options={
                    'maxiter': max_iter,
                    'disp': True,
                    'maxfun': 15000,
                    'maxls': 50
                }
            )

            # Store optimized parameters
            self._unpack_parameters(result.x)

            # Recompute MAP estimates with the final parameters
            _ = self.log_likelihood(result.x, all_stats)

            return result

        except Exception as e:
            print(f"Optimization failed: {str(e)}")
            return None

    def _unpack_parameters(self, params):
        """Store the optimized parameter vector back onto the model."""
        param_idx = 0

        # W
        W_size = self.n_hidden * (self.n_input + 1)
        self.W = params[param_idx:param_idx + W_size].reshape(self.n_hidden, self.n_input + 1)
        param_idx += W_size

        # m
        m_size = self.n_hidden + 1
        self.m = params[param_idx:param_idx + m_size]
        param_idx += m_size

        # Sigma from its Cholesky factor, with jitter for positive definiteness
        L = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
        tril_indices = np.tril_indices(self.n_hidden + 1)
        L[tril_indices] = params[param_idx:param_idx + len(tril_indices[0])]
        param_idx += len(tril_indices[0])
        self.Sigma = np.dot(L, L.T) + 1e-6 * np.eye(self.n_hidden + 1)

        # sigma, clamped away from zero
        self.sigma = max(np.exp(params[param_idx]), 1e-8)
class MultiTaskNNDependentMean(MultiTaskNNBase):
    """Multi-task NN whose prior mean over task weights depends on task features.

    Each task i has prior A_i ~ N(M @ f_i, Sigma), where f_i is that task's
    feature vector and M is a learned linear mapping. Hyperparameters
    (W, M, Sigma, sigma) are fit by maximizing the marginal log likelihood
    with L-BFGS-B.
    """

    def __init__(self, n_input, n_hidden, n_tasks, n_features, activation='tanh'):
        super().__init__(n_input, n_hidden, n_tasks, activation)
        self.n_features = n_features

        # M maps task features to the prior mean of the task weights.
        self.M = np.random.randn(n_hidden + 1, n_features) * 0.1
        self.Sigma = np.eye(n_hidden + 1) * 0.5  # start with smaller variance
        self.sigma = 1.0

    def log_likelihood(self, params, all_stats, all_task_features):
        """Marginal log likelihood with task-dependent prior means.

        ``params`` packs W (flattened), M (flattened), the packed lower
        triangle of the Cholesky factor of Sigma, and log(sigma).
        Returns -inf when a covariance factorization fails.
        Side effect: rebuilds ``self.A_map`` with per-task MAP estimates.
        """
        param_idx = 0

        # W is unpacked for layout completeness; the sufficient statistics
        # were precomputed, so W does not influence this evaluation.
        W_size = self.n_hidden * (self.n_input + 1)
        W = params[param_idx:param_idx + W_size].reshape(self.n_hidden, self.n_input + 1)
        param_idx += W_size

        # M: (n_hidden + 1) x n_features
        M_size = (self.n_hidden + 1) * self.n_features
        M = params[param_idx:param_idx + M_size].reshape(self.n_hidden + 1, self.n_features)
        param_idx += M_size

        # Sigma via its Cholesky factor L (lower triangle packed)
        L = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
        tril_indices = np.tril_indices(self.n_hidden + 1)
        L[tril_indices] = params[param_idx:param_idx + len(tril_indices[0])]
        param_idx += len(tril_indices[0])

        # sigma optimized on the log scale to keep it positive
        log_sigma = params[param_idx]
        sigma = np.exp(log_sigma)

        total_log_lik = 0.0
        self.A_map = []

        # Sigma^{-1} from the Cholesky factor
        try:
            Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
            Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)
        except np.linalg.LinAlgError:
            return -np.inf  # invalid covariance matrix

        for stats, task_features in zip(all_stats, all_task_features):
            sum_hhT = stats['sum_hhT']
            sum_hy = stats['sum_hy']
            sum_yy = stats['sum_yy']
            n_samples = stats['n_samples']

            # Task-dependent prior mean
            m_i = np.dot(M, task_features)

            # Posterior precision of the task weights
            Q_i = (1.0 / (sigma ** 2)) * sum_hhT + Sigma_inv

            try:
                L_Q = np.linalg.cholesky(Q_i)
                Q_inv = solve_triangular(L_Q, np.eye(self.n_hidden + 1), lower=True)
                Q_inv = np.dot(Q_inv.T, Q_inv)
            except np.linalg.LinAlgError:
                return -np.inf

            R_i = (1.0 / (sigma ** 2)) * sum_hy + np.dot(Sigma_inv, m_i)

            # MAP estimate of the task weights
            A_i = np.dot(Q_inv, R_i)
            self.A_map.append(A_i)

            # log-determinants read off the Cholesky diagonals
            logdet_Q_i = 2 * np.sum(np.log(np.diag(L_Q)))
            logdet_Sigma = 2 * np.sum(np.log(np.diag(L)))

            # Marginal likelihood terms for this task
            term1 = -0.5 * (logdet_Sigma + n_samples * 2 * log_sigma + logdet_Q_i)
            term2 = 0.5 * (np.dot(R_i, np.dot(Q_inv, R_i))
                           - (1.0 / (sigma ** 2)) * sum_yy
                           - np.dot(m_i, np.dot(Sigma_inv, m_i)))

            total_log_lik += term1 + term2

        return total_log_lik

    def fit(self, X_list, y_list, task_features_list, max_iter=100):
        """Fit hyperparameters by maximizing the marginal log likelihood.

        X_list/y_list hold per-task training data; task_features_list holds
        each task's feature vector. Returns the scipy optimization result.
        """
        # Standardize each task's inputs and targets
        X_list = [(X - np.mean(X, axis=0)) / (np.std(X, axis=0) + 1e-8) for X in X_list]
        y_list = [(y - np.mean(y)) / (np.std(y) + 1e-8) for y in y_list]

        # Per-task sufficient statistics
        all_stats = [self.compute_sufficient_statistics(X, y) for X, y in zip(X_list, y_list)]

        # Pack the initial parameter vector: W, M, chol(Sigma), log(sigma)
        initial_params = []
        initial_params.extend(self.W.flatten())
        initial_params.extend(self.M.flatten())

        # Jitter guarantees the initial Cholesky succeeds
        L = np.linalg.cholesky(self.Sigma + 1e-6 * np.eye(self.n_hidden + 1))
        tril_indices = np.tril_indices(self.n_hidden + 1)
        initial_params.extend(L[tril_indices])

        initial_params.append(np.log(self.sigma))

        # Bounds keep the factorization valid during optimization
        bounds = []

        # W - unbounded
        bounds.extend([(None, None)] * (self.n_hidden * (self.n_input + 1)))

        # M - unbounded
        bounds.extend([(None, None)] * ((self.n_hidden + 1) * self.n_features))

        # Cholesky factor: diagonal entries must stay positive
        for i in range(len(tril_indices[0])):
            if tril_indices[0][i] == tril_indices[1][i]:  # diagonal
                bounds.append((1e-8, None))
            else:
                bounds.append((None, None))

        # log_sigma must stay above log(1e-8)
        bounds.append((np.log(1e-8), None))

        # Optimize
        result = minimize(
            lambda p: -self.log_likelihood(p, all_stats, task_features_list),
            initial_params,
            method='L-BFGS-B',
            bounds=bounds,
            options={'maxiter': max_iter, 'disp': True}
        )

        # Store optimized parameters
        self._unpack_parameters(result.x)

        # Recompute MAP estimates with the final parameters
        _ = self.log_likelihood(result.x, all_stats, task_features_list)

        return result

    def _unpack_parameters(self, params):
        """Store the optimized parameter vector back onto the model."""
        param_idx = 0

        # W
        W_size = self.n_hidden * (self.n_input + 1)
        self.W = params[param_idx:param_idx + W_size].reshape(self.n_hidden, self.n_input + 1)
        param_idx += W_size

        # M: (n_hidden + 1) x n_features.
        # BUG FIX: the original unpacked M into a local variable and never
        # assigned it, so the optimized feature-to-mean mapping was silently
        # discarded and the model kept its random initialization.
        M_size = (self.n_hidden + 1) * self.n_features
        self.M = params[param_idx:param_idx + M_size].reshape(self.n_hidden + 1, self.n_features)
        param_idx += M_size

        # Sigma from its Cholesky factor
        L = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
        tril_indices = np.tril_indices(self.n_hidden + 1)
        L[tril_indices] = params[param_idx:param_idx + len(tril_indices[0])]
        param_idx += len(tril_indices[0])
        self.Sigma = np.dot(L, L.T)

        # sigma
        self.sigma = np.exp(params[param_idx])
class MultiTaskNNClustering(MultiTaskNNBase):
    """Multi-task NN that soft-clusters tasks with an EM algorithm.

    Cluster alpha has mixing weight q[alpha] and a Gaussian prior
    N(m[alpha], Sigma[alpha]) over task weights; z holds the soft
    task-to-cluster responsibilities.
    """

    def __init__(self, n_input, n_hidden, n_tasks, n_clusters, activation='tanh'):
        super().__init__(n_input, n_hidden, n_tasks, activation)
        self.n_clusters = n_clusters

        # Uniform mixing weights and random per-cluster prior means
        self.q = np.ones(n_clusters) / n_clusters
        self.m = np.random.randn(n_clusters, n_hidden + 1) * 0.5

        # Per-cluster prior covariances, diagonal for numerical stability
        self.Sigma = np.array([np.eye(n_hidden + 1) * 0.5 for _ in range(n_clusters)])
        self.sigma = 1.0

        # Soft responsibilities: tasks x clusters
        self.z = np.zeros((n_tasks, n_clusters))

    def _compute_task_log_likelihood(self, X_i, y_i, cluster_idx):
        """Marginal log likelihood of task (X_i, y_i) under one cluster's prior."""
        n_i = len(y_i)
        h_i = self.compute_hidden_activations(X_i)

        # Guard against division by zero
        sigma_sq = max(self.sigma ** 2, 1e-8)

        try:
            # Cholesky-based inverses for numerical stability
            L = cholesky(self.Sigma[cluster_idx], lower=True)
            Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
            Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)

            Q_i = (1 / sigma_sq) * np.dot(h_i.T, h_i) + Sigma_inv
            L_Q = cholesky(Q_i, lower=True)
            Q_inv = solve_triangular(L_Q, np.eye(self.n_hidden + 1), lower=True)
            Q_inv = np.dot(Q_inv.T, Q_inv)

            R_i = (1 / sigma_sq) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[cluster_idx])

            # log-determinants read off the Cholesky diagonals
            logdet_Sigma = 2 * np.sum(np.log(np.diag(L)))
            logdet_Q_i = 2 * np.sum(np.log(np.diag(L_Q)))

            term1 = -0.5 * (logdet_Sigma + n_i * np.log(sigma_sq) + logdet_Q_i)
            # BUG FIX: the data term used 1/(2*sigma_sq) inside the outer 0.5
            # factor, double-discounting sum(y^2); MultiTaskNN.log_likelihood
            # in this file uses the correct 1/sigma_sq weight.
            term2 = 0.5 * (np.dot(R_i.T, np.dot(Q_inv, R_i))
                           - (1 / sigma_sq) * np.sum(y_i ** 2)
                           - np.dot(self.m[cluster_idx].T, np.dot(Sigma_inv, self.m[cluster_idx])))

            return term1 + term2

        except np.linalg.LinAlgError:
            # Matrix not positive definite: treat as impossible assignment
            return -np.inf

    def e_step(self, data):
        """Expectation step: update soft responsibilities z."""
        log_responsibilities = np.zeros((self.n_tasks, self.n_clusters))

        for i, (X_i, y_i) in enumerate(data):
            for alpha in range(self.n_clusters):
                log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                log_responsibilities[i, alpha] = np.log(self.q[alpha] + 1e-8) + log_lik

            # Normalize in log space for numerical stability
            log_responsibilities[i] -= logsumexp(log_responsibilities[i])

        self.z = np.exp(log_responsibilities)

    def m_step(self, data):
        """Maximization step: update W, sigma, and per-cluster priors."""

        def objective(params):
            # Unpack W and log(sigma); the model is mutated in place so that
            # _compute_task_log_likelihood sees the candidate parameters.
            W = params[:self.n_hidden * (self.n_input + 1)].reshape(self.n_hidden, self.n_input + 1)
            log_sigma = params[-1]
            sigma = np.exp(log_sigma)

            self.W = W
            self.sigma = max(sigma, 1e-8)  # prevent sigma from collapsing

            total_log_lik = 0.0
            for i, (X_i, y_i) in enumerate(data):
                for alpha in range(self.n_clusters):
                    log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                    total_log_lik += self.z[i, alpha] * log_lik

            return -total_log_lik if np.isfinite(total_log_lik) else np.inf

        # Initial parameter vector: W and log(sigma)
        initial_params = np.concatenate([
            self.W.flatten(),
            [np.log(self.sigma)]
        ])

        # Bound log_sigma so sigma stays above 1e-8
        bounds = [(None, None)] * len(initial_params)
        bounds[-1] = (np.log(1e-8), None)

        result = minimize(
            objective,
            initial_params,
            method='L-BFGS-B',
            bounds=bounds,
            options={'maxiter': 50, 'disp': True}
        )

        opt_params = result.x
        W_size = self.n_hidden * (self.n_input + 1)
        self.W = opt_params[:W_size].reshape(self.n_hidden, self.n_input + 1)
        self.sigma = max(np.exp(opt_params[-1]), 1e-8)

        # Update cluster parameters with regularization
        for alpha in range(self.n_clusters):
            self.q[alpha] = max(np.sum(self.z[:, alpha]) / self.n_tasks, 1e-8)

            sum_z = np.sum(self.z[:, alpha])
            if sum_z > 1e-8:
                # Update m_alpha from responsibility-weighted normal equations
                weighted_R = np.zeros(self.n_hidden + 1)
                weighted_Q = np.zeros((self.n_hidden + 1, self.n_hidden + 1))

                for i, (X_i, y_i) in enumerate(data):
                    h_i = self.compute_hidden_activations(X_i)
                    L = cholesky(self.Sigma[alpha] + 1e-6 * np.eye(self.n_hidden + 1), lower=True)
                    Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
                    Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)

                    Q_i = (1 / max(self.sigma ** 2, 1e-8)) * np.dot(h_i.T, h_i) + Sigma_inv
                    R_i = (1 / max(self.sigma ** 2, 1e-8)) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[alpha])

                    weighted_R += self.z[i, alpha] * R_i
                    weighted_Q += self.z[i, alpha] * Q_i

                try:
                    self.m[alpha] = np.linalg.solve(weighted_Q + 1e-6 * np.eye(self.n_hidden + 1), weighted_R)
                except np.linalg.LinAlgError:
                    # Singular system: keep the previous mean (was a bare except)
                    pass

                # Update covariance from responsibility-weighted scatter
                weighted_cov = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
                for i, (X_i, y_i) in enumerate(data):
                    h_i = self.compute_hidden_activations(X_i)
                    A_i = self._compute_map_estimate(X_i, y_i, alpha)
                    diff = A_i - self.m[alpha]
                    weighted_cov += self.z[i, alpha] * np.outer(diff, diff)

                self.Sigma[alpha] = weighted_cov / sum_z + 1e-6 * np.eye(self.n_hidden + 1)

    def _compute_map_estimate(self, X_i, y_i, cluster_idx):
        """MAP estimate of task weights under one cluster's prior."""
        h_i = self.compute_hidden_activations(X_i)
        sigma_sq = max(self.sigma ** 2, 1e-8)

        L = cholesky(self.Sigma[cluster_idx] + 1e-6 * np.eye(self.n_hidden + 1), lower=True)
        Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
        Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)

        Q_i = (1 / sigma_sq) * np.dot(h_i.T, h_i) + Sigma_inv
        R_i = (1 / sigma_sq) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[cluster_idx])

        return np.linalg.solve(Q_i + 1e-6 * np.eye(self.n_hidden + 1), R_i)

    def fit(self, data, max_iter=100, tol=1e-4):
        """Run EM until the mixture log likelihood converges or max_iter."""
        prev_log_lik = -np.inf

        for iteration in tqdm(range(max_iter)):
            self.e_step(data)
            self.m_step(data)

            # Mixture log likelihood for the convergence check
            current_log_lik = 0.0
            for i, (X_i, y_i) in enumerate(data):
                cluster_log_liks = []
                for alpha in range(self.n_clusters):
                    log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                    cluster_log_liks.append(np.log(self.q[alpha] + 1e-8) + log_lik)
                current_log_lik += logsumexp(cluster_log_liks)

            if np.isnan(current_log_lik):
                print("Warning: log likelihood is nan, stopping early")
                break

            if iteration > 0 and np.abs(current_log_lik - prev_log_lik) < tol:
                print(f"Converged at iteration {iteration}")
                break

            prev_log_lik = current_log_lik

        self._compute_final_weights(data)

    # NOTE: the original file defined the three methods below twice verbatim;
    # the dead duplicate definitions have been removed.
    def _compute_final_weights(self, data):
        """Assign each task its MAP weights under its most likely cluster."""
        for i, (X_i, y_i) in enumerate(data):
            most_likely_cluster = np.argmax(self.z[i])
            self.A_map[i] = self._compute_map_estimate(X_i, y_i, most_likely_cluster)

    def get_cluster_assignments(self):
        """Hard cluster assignment (argmax responsibility) per task."""
        return np.argmax(self.z, axis=1)

    def get_task_similarity(self):
        """Binary task-similarity matrix: 1 when two tasks share a cluster."""
        assignments = self.get_cluster_assignments()
        return np.array([[1.0 if a == b else 0.0 for b in assignments] for a in assignments])
class MultiTaskNNGating(MultiTaskNNBase):
    """Multi-task NN with a gating network over task features.

    A softmax gating network with weights U maps task features to cluster
    probabilities; cluster alpha has a Gaussian prior N(m[alpha],
    Sigma[alpha]) over task weights. Fit with an EM-style loop that
    alternates responsibilities, likelihood parameters, and gating weights.
    """

    def __init__(self, n_input, n_hidden, n_tasks, n_clusters, n_features, activation='tanh'):
        super().__init__(n_input, n_hidden, n_tasks, activation)
        self.n_clusters = n_clusters
        self.n_features = n_features

        # Gating weights and per-cluster prior means
        self.U = np.random.randn(n_clusters, n_features) * 0.5
        self.m = np.random.randn(n_clusters, n_hidden + 1) * 0.5

        # Per-cluster prior covariances, diagonal for numerical stability
        self.Sigma = np.array([np.eye(n_hidden + 1) * 0.5 for _ in range(n_clusters)])
        self.sigma = 1.0

        # Soft responsibilities: tasks x clusters
        self.z = np.zeros((n_tasks, n_clusters))

    def compute_gating_probabilities(self, F):
        """Return softmax task-to-cluster probabilities for task features F."""
        # Ensure F is a 2-D (tasks x features) array
        F = np.atleast_2d(F)
        if F.shape[0] == 1 and self.n_tasks > 1:
            F = np.repeat(F, self.n_tasks, axis=0)

        logits = np.dot(F, self.U.T)
        return softmax(logits, axis=1)

    def _compute_task_log_likelihood(self, X_i, y_i, cluster_idx):
        """Marginal log likelihood of task (X_i, y_i) under one cluster's prior."""
        n_i = len(y_i)
        h_i = self.compute_hidden_activations(X_i)

        # Guard against division by zero
        sigma_sq = max(self.sigma ** 2, 1e-8)

        try:
            # Cholesky-based inverses for numerical stability
            L = cholesky(self.Sigma[cluster_idx] + 1e-6 * np.eye(self.n_hidden + 1), lower=True)
            Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
            Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)

            Q_i = (1 / sigma_sq) * np.dot(h_i.T, h_i) + Sigma_inv
            L_Q = cholesky(Q_i, lower=True)
            Q_inv = solve_triangular(L_Q, np.eye(self.n_hidden + 1), lower=True)
            Q_inv = np.dot(Q_inv.T, Q_inv)

            R_i = (1 / sigma_sq) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[cluster_idx])

            # log-determinants read off the Cholesky diagonals
            logdet_Sigma = 2 * np.sum(np.log(np.diag(L)))
            logdet_Q_i = 2 * np.sum(np.log(np.diag(L_Q)))

            term1 = -0.5 * (logdet_Sigma + n_i * np.log(sigma_sq) + logdet_Q_i)
            # BUG FIX: the data term used 1/(2*sigma_sq) inside the outer 0.5
            # factor, double-discounting sum(y^2); MultiTaskNN.log_likelihood
            # in this file uses the correct 1/sigma_sq weight.
            term2 = 0.5 * (np.dot(R_i.T, np.dot(Q_inv, R_i))
                           - (1 / sigma_sq) * np.sum(y_i ** 2)
                           - np.dot(self.m[cluster_idx].T, np.dot(Sigma_inv, self.m[cluster_idx])))

            return term1 + term2

        except np.linalg.LinAlgError:
            # Matrix not positive definite: treat as impossible assignment
            return -np.inf

    def e_step(self, data, task_features):
        """Expectation step: update soft responsibilities z from the gating prior."""
        # Ensure task_features is a 2-D (tasks x features) array
        task_features = np.atleast_2d(task_features)
        if task_features.shape[0] == 1 and self.n_tasks > 1:
            task_features = np.repeat(task_features, self.n_tasks, axis=0)

        q = self.compute_gating_probabilities(task_features)
        log_responsibilities = np.zeros((self.n_tasks, self.n_clusters))

        for i, (X_i, y_i) in enumerate(data):
            for alpha in range(self.n_clusters):
                log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                log_responsibilities[i, alpha] = np.log(q[i, alpha] + 1e-8) + log_lik

            # Normalize in log space for numerical stability
            log_responsibilities[i] -= logsumexp(log_responsibilities[i])

        self.z = np.exp(log_responsibilities)

    def m_step(self, data, task_features):
        """Maximization step: update W, sigma, cluster priors, and gating U."""

        def objective(params):
            # Unpack W and log(sigma); the model is mutated in place so that
            # _compute_task_log_likelihood sees the candidate parameters.
            W = params[:self.n_hidden * (self.n_input + 1)].reshape(self.n_hidden, self.n_input + 1)
            log_sigma = params[-1]
            sigma = np.exp(log_sigma)

            self.W = W
            self.sigma = max(sigma, 1e-8)

            total_log_lik = 0.0
            for i, (X_i, y_i) in enumerate(data):
                for alpha in range(self.n_clusters):
                    log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                    total_log_lik += self.z[i, alpha] * log_lik

            return -total_log_lik if np.isfinite(total_log_lik) else np.inf

        # Initial parameter vector: W and log(sigma)
        initial_params = np.concatenate([
            self.W.flatten(),
            [np.log(self.sigma)]
        ])

        bounds = [(None, None)] * len(initial_params)
        bounds[-1] = (np.log(1e-8), None)  # sigma > 1e-8

        result = minimize(
            objective,
            initial_params,
            method='L-BFGS-B',
            bounds=bounds,
            options={'maxiter': 50, 'disp': True}
        )

        # Store optimized likelihood parameters
        opt_params = result.x
        W_size = self.n_hidden * (self.n_input + 1)
        self.W = opt_params[:W_size].reshape(self.n_hidden, self.n_input + 1)
        self.sigma = max(np.exp(opt_params[-1]), 1e-8)

        # Update cluster parameters with regularization
        for alpha in range(self.n_clusters):
            sum_z = np.sum(self.z[:, alpha])
            if sum_z > 1e-8:
                # Update m_alpha from responsibility-weighted normal equations
                weighted_R = np.zeros(self.n_hidden + 1)
                weighted_Q = np.zeros((self.n_hidden + 1, self.n_hidden + 1))

                for i, (X_i, y_i) in enumerate(data):
                    h_i = self.compute_hidden_activations(X_i)
                    L = cholesky(self.Sigma[alpha] + 1e-6 * np.eye(self.n_hidden + 1), lower=True)
                    Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
                    Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)

                    Q_i = (1 / max(self.sigma ** 2, 1e-8)) * np.dot(h_i.T, h_i) + Sigma_inv
                    R_i = (1 / max(self.sigma ** 2, 1e-8)) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[alpha])

                    weighted_R += self.z[i, alpha] * R_i
                    weighted_Q += self.z[i, alpha] * Q_i

                try:
                    self.m[alpha] = np.linalg.solve(weighted_Q + 1e-6 * np.eye(self.n_hidden + 1), weighted_R)
                except np.linalg.LinAlgError:
                    # Singular system: keep the previous mean (was a bare except)
                    pass

                # Update Sigma_alpha from responsibility-weighted scatter
                weighted_cov = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
                for i, (X_i, y_i) in enumerate(data):
                    h_i = self.compute_hidden_activations(X_i)
                    A_i = self._compute_map_estimate(X_i, y_i, alpha)
                    diff = A_i - self.m[alpha]
                    weighted_cov += self.z[i, alpha] * np.outer(diff, diff)

                self.Sigma[alpha] = weighted_cov / sum_z + 1e-6 * np.eye(self.n_hidden + 1)

        # Update gating weights U via weighted multinomial logistic regression
        if self.n_clusters > 1:
            task_features = np.atleast_2d(task_features)
            if task_features.shape[0] == 1 and self.n_tasks > 1:
                task_features = np.repeat(task_features, self.n_tasks, axis=0)

            # NOTE(review): multi_class='multinomial' is deprecated in recent
            # scikit-learn versions (it is the default for lbfgs) — confirm
            # against the pinned sklearn version.
            lr = LogisticRegression(
                multi_class='multinomial',
                solver='lbfgs',
                fit_intercept=False,
                max_iter=100,
                penalty='l2',
                C=1.0
            )
            try:
                lr.fit(task_features, self.get_cluster_assignments(), sample_weight=np.max(self.z, axis=1))
                self.U = lr.coef_
            except Exception:
                # Best-effort update (e.g. only one class present in the
                # assignments); keep the previous U. Was a bare except.
                pass

    def _compute_map_estimate(self, X_i, y_i, cluster_idx):
        """MAP estimate of task weights under one cluster's prior."""
        h_i = self.compute_hidden_activations(X_i)
        sigma_sq = max(self.sigma ** 2, 1e-8)

        L = cholesky(self.Sigma[cluster_idx] + 1e-6 * np.eye(self.n_hidden + 1), lower=True)
        Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
        Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)

        Q_i = (1 / sigma_sq) * np.dot(h_i.T, h_i) + Sigma_inv
        R_i = (1 / sigma_sq) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[cluster_idx])

        return np.linalg.solve(Q_i + 1e-6 * np.eye(self.n_hidden + 1), R_i)

    def fit(self, data, task_features, max_iter=100, tol=1e-4):
        """Run EM with gating until convergence or max_iter; returns self."""
        prev_log_lik = -np.inf

        # Standardize task features; the stats are kept on the model
        task_features = np.atleast_2d(task_features)
        if task_features.shape[0] == 1 and self.n_tasks > 1:
            task_features = np.repeat(task_features, self.n_tasks, axis=0)

        self.task_feature_mean = np.mean(task_features, axis=0)
        self.task_feature_std = np.std(task_features, axis=0) + 1e-8
        task_features = (task_features - self.task_feature_mean) / self.task_feature_std

        for iteration in range(max_iter):
            try:
                self.e_step(data, task_features)
                self.m_step(data, task_features)

                # Mixture log likelihood for the convergence check
                current_log_lik = 0.0
                q = self.compute_gating_probabilities(task_features)

                for i, (X_i, y_i) in enumerate(data):
                    cluster_log_liks = []
                    for alpha in range(self.n_clusters):
                        log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                        cluster_log_liks.append(np.log(q[i, alpha] + 1e-8) + log_lik)
                    current_log_lik += logsumexp(cluster_log_liks)

                if np.isnan(current_log_lik):
                    print("Warning: log likelihood is nan, stopping early")
                    break

                if iteration > 0 and abs(current_log_lik - prev_log_lik) < tol:
                    print(f"Converged at iteration {iteration}")
                    break

                prev_log_lik = current_log_lik
                print(f"Iteration {iteration}, log likelihood: {current_log_lik}")

            except Exception as e:
                print(f"Error at iteration {iteration}: {str(e)}")
                break

        self._compute_final_weights(data)
        return self

    def _compute_final_weights(self, data):
        """Assign each task its MAP weights under its most likely cluster."""
        for i, (X_i, y_i) in enumerate(data):
            most_likely_cluster = np.argmax(self.z[i])
            self.A_map[i] = self._compute_map_estimate(X_i, y_i, most_likely_cluster)

    def get_cluster_assignments(self):
        """Hard cluster assignment (argmax responsibility) per task."""
        return np.argmax(self.z, axis=1)

    def get_task_similarity(self):
        """Binary task-similarity matrix: 1 when two tasks share a cluster."""
        assignments = self.get_cluster_assignments()
        return np.array([[1.0 if a == b else 0.0 for b in assignments] for a in assignments])