Coverage for bmm_multitask_learning/task_clustering/MultiTask_Base_Class.py: 0%

31 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-13 13:33 +0000

1import numpy as np 

class MultiTaskNNBase:
    """Base class for all multi-task neural network variants.

    The network has one hidden layer shared by every task plus one linear
    output weight vector per task:

    * ``W`` has shape ``(n_hidden, n_input + 1)``; its last column multiplies
      a constant bias input of 1.
    * ``A_map[t]`` has shape ``(n_hidden + 1,)``; its last entry multiplies a
      constant bias hidden unit of 1.
    """

    def __init__(self, n_input, n_hidden, n_tasks, activation='tanh'):
        """Initialize base multi-task neural network.

        Args:
            n_input: Number of input features.
            n_hidden: Number of hidden units.
            n_tasks: Number of tasks.
            activation: Activation function ('tanh' or 'linear'). The value
                is only checked when activations are computed, so subclasses
                may override ``_activate`` to support other names.
        """
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.n_tasks = n_tasks
        self.activation = activation

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self, scale=0.5, rng=None):
        """Initialize network weights.

        Args:
            scale: Standard deviation multiplier for the Gaussian draws of ``W``.
            rng: Optional ``numpy.random.Generator``. When ``None``, the global
                NumPy RNG is used (so ``np.random.seed`` controls the result).

        Sets ``self.W`` to scaled standard-normal draws and every
        ``self.A_map[t]`` to a zero vector of length ``n_hidden + 1``.
        """
        shape = (self.n_hidden, self.n_input + 1)
        if rng is None:
            draws = np.random.randn(*shape)
        else:
            draws = rng.standard_normal(shape)
        self.W = draws * scale
        self.A_map = [np.zeros(self.n_hidden + 1) for _ in range(self.n_tasks)]

    def _activate(self, x):
        """Apply the configured activation function element-wise.

        Raises:
            ValueError: If ``self.activation`` is neither 'tanh' nor 'linear'.
        """
        if self.activation == 'tanh':
            return np.tanh(x)
        elif self.activation == 'linear':
            return x
        else:
            raise ValueError("Activation must be 'tanh' or 'linear'")

    def compute_hidden_activations(self, X):
        """Compute hidden-unit activations with a trailing bias column.

        Args:
            X: Inputs of shape ``(n_samples, n_input)``. A 1-D array of length
                ``n_input`` is treated as a single sample.

        Returns:
            Array of shape ``(n_samples, n_hidden + 1)``. The last column is
            all ones (the hidden bias unit).
        """
        # Promote a single 1-D sample (or a nested list) to a 2-D matrix so the
        # bias column can be stacked onto it.
        X = np.atleast_2d(X)
        X_bias = np.hstack([X, np.ones((X.shape[0], 1))])
        H = self._activate(np.dot(X_bias, self.W.T))
        return np.hstack([H, np.ones((H.shape[0], 1))])

    def predict(self, X, task_idx):
        """Make predictions for a specific task.

        Args:
            X: Inputs, as accepted by ``compute_hidden_activations``.
            task_idx: Index into ``self.A_map`` selecting the task's output weights.

        Returns:
            1-D array with one prediction per sample.
        """
        H_bias = self.compute_hidden_activations(X)
        return np.dot(H_bias, self.A_map[task_idx])

    def compute_sufficient_statistics(self, X, y):
        """Compute sufficient statistics for a single task.

        Args:
            X: Inputs, as accepted by ``compute_hidden_activations``.
            y: Targets, either 1-D of length ``n_samples`` or a column vector
                of shape ``(n_samples, 1)``. They are flattened so that
                ``sum_hy`` has the same shape as ``A_map[t]``.

        Returns:
            Dict with keys:
            * ``'sum_hhT'``: ``H^T H``, shape ``(n_hidden + 1, n_hidden + 1)``.
            * ``'sum_hy'``: ``H^T y``, shape ``(n_hidden + 1,)``.
            * ``'sum_yy'``: ``y^T y`` as a scalar.
            * ``'n_samples'``: number of rows of ``H``.

        Raises:
            ValueError: If the number of targets differs from the number of samples.
        """
        H_bias = self.compute_hidden_activations(X)
        # Flatten so a column-vector target neither breaks y.y nor yields a
        # (h+1, 1) sum_hy that mismatches the (h+1,) output weight vectors.
        y = np.asarray(y).ravel()
        n_samples = H_bias.shape[0]
        if y.shape[0] != n_samples:
            raise ValueError(
                f"y has {y.shape[0]} entries but X has {n_samples} samples"
            )
        return {
            'sum_hhT': np.dot(H_bias.T, H_bias),
            'sum_hy': np.dot(H_bias.T, y),
            'sum_yy': np.dot(y, y),
            'n_samples': n_samples,
        }

    def _normalize_data(self, X_list, y_list, eps=1e-8):
        """Standardize inputs and targets task by task.

        Each task's ``X`` is centered and scaled per feature (column), and each
        task's ``y`` is centered and scaled by its own mean and std. The
        statistics are computed separately for every task and are NOT
        returned, so predictions remain on the normalized scale.

        Args:
            X_list: Sequence of per-task input arrays, each ``(n_t, n_input)``.
            y_list: Sequence of per-task target arrays.
            eps: Added to each standard deviation to avoid division by zero
                for constant features or targets.

        Returns:
            Tuple ``(X_norm, y_norm)`` of lists in the same task order.
        """
        X_norm = []
        for X in X_list:
            X = np.asarray(X)
            X_norm.append((X - np.mean(X, axis=0)) / (np.std(X, axis=0) + eps))
        y_norm = []
        for y in y_list:
            y = np.asarray(y)
            y_norm.append((y - np.mean(y)) / (np.std(y) + eps))
        return X_norm, y_norm