Coverage for bmm_multitask_learning/task_clustering/MultiTask_Base_Class.py: 0%
31 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-13 13:33 +0000
1import numpy as np
class MultiTaskNNBase:
    """Base class for all multi-task neural network variants.

    Holds a single shared hidden layer (weights ``W``) plus one linear
    output weight vector per task (``A_map``). Bias terms are modeled by
    appending a constant-1 column to the inputs and to the hidden
    activations rather than keeping separate bias vectors.
    """

    # Activation names accepted by this base class.
    _VALID_ACTIVATIONS = ('tanh', 'linear')

    def __init__(self, n_input, n_hidden, n_tasks, activation='tanh'):
        """Initialize base multi-task neural network.

        Args:
            n_input: Number of input features.
            n_hidden: Number of hidden units.
            n_tasks: Number of tasks.
            activation: Activation function ('tanh' or 'linear').

        Raises:
            ValueError: If ``activation`` is not a supported name.
        """
        # Fail fast: previously a bad activation string was accepted here
        # and only raised on the first forward pass in _activate().
        if activation not in self._VALID_ACTIVATIONS:
            raise ValueError("Activation must be 'tanh' or 'linear'")

        self.n_input = n_input
        self.n_hidden = n_hidden
        self.n_tasks = n_tasks
        self.activation = activation

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self, scale=0.5):
        """Initialize network weights with given scale.

        Args:
            scale: Standard-deviation multiplier for the shared weights.
        """
        # Shared hidden weights: one row per hidden unit, +1 column for
        # the input bias term.
        self.W = np.random.randn(self.n_hidden, self.n_input + 1) * scale
        # Per-task output weights start at zero (+1 entry for hidden bias).
        self.A_map = [np.zeros(self.n_hidden + 1) for _ in range(self.n_tasks)]

    def _activate(self, x):
        """Apply the configured activation function element-wise."""
        if self.activation == 'tanh':
            return np.tanh(x)
        elif self.activation == 'linear':
            return x
        else:
            # Kept as a safety net for callers that mutate
            # self.activation after construction.
            raise ValueError("Activation must be 'tanh' or 'linear'")

    def compute_hidden_activations(self, X):
        """Compute hidden unit activations with bias.

        Args:
            X: Array of shape (n_samples, n_input).

        Returns:
            Array of shape (n_samples, n_hidden + 1); the last column is
            the constant-1 bias term.
        """
        X_bias = np.hstack([X, np.ones((X.shape[0], 1))])
        H = self._activate(np.dot(X_bias, self.W.T))
        return np.hstack([H, np.ones((H.shape[0], 1))])

    def predict(self, X, task_idx):
        """Make predictions for a specific task.

        Args:
            X: Array of shape (n_samples, n_input).
            task_idx: Index into ``A_map`` selecting the task's head.

        Returns:
            Array of shape (n_samples,) with the task's predictions.
        """
        H_bias = self.compute_hidden_activations(X)
        return np.dot(H_bias, self.A_map[task_idx])

    def compute_sufficient_statistics(self, X, y):
        """Compute sufficient statistics for a single task.

        Args:
            X: Array of shape (n_samples, n_input).
            y: Target array; assumed 1-D of shape (n_samples,) — TODO
               confirm against callers (np.dot(y, y) requires 1-D here).

        Returns:
            Dict with 'sum_hhT' (H^T H), 'sum_hy' (H^T y), 'sum_yy'
            (y . y) and 'n_samples' — the quantities needed for
            closed-form updates of the task's output weights.
        """
        H_bias = self.compute_hidden_activations(X)
        return {
            'sum_hhT': np.dot(H_bias.T, H_bias),
            'sum_hy': np.dot(H_bias.T, y),
            'sum_yy': np.dot(y, y),
            'n_samples': X.shape[0]
        }

    def _normalize_data(self, X_list, y_list):
        """Standardize each task's data to zero mean / unit variance.

        Args:
            X_list: List of per-task input arrays (normalized per column).
            y_list: List of per-task target arrays.

        Returns:
            Tuple (X_norm, y_norm) of normalized lists; the 1e-8 epsilon
            guards against division by zero on constant features.
        """
        X_norm = [(X - np.mean(X, axis=0)) / (np.std(X, axis=0) + 1e-8) for X in X_list]
        y_norm = [(y - np.mean(y)) / (np.std(y) + 1e-8) for y in y_list]
        return X_norm, y_norm