Coverage for bmm_multitask_learning/sbmtl/sparse_bayesian_regression.py: 90%

228 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-13 13:33 +0000

1import torch 

2import torch.nn as nn 

3import torch.nn.functional as F 

4from torch import Tensor 

5from typing import List, Optional, Dict, Any 

6from scipy.optimize import root 

7from scipy.special import kv, psi 

8import numpy as np 

9 

10 

11class SparseBayesianRegression: 

12 def __init__(self, model: nn.Module, group_indices: List[List[int]], 

13 device: Optional[str] = None): 

14 """ 

15 model: torch.nn.Module (линейная или любая torch-модель) 

16 group_indices: список списков индексов параметров, соответствующих группам 

17 device: cpu/cuda 

18 """ 

19 self.model = model 

20 self.group_indices = group_indices 

21 self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") 

22 self.model.to(self.device) 

23 self._init_hyperparams() 

24 

25 def _init_hyperparams(self) -> None: 

26 """ 

27 Инициализирует гиперпараметры модели, включая параметры прайора и постериора. 

28 """ 

29 G = len(self.group_indices) 

30 # Общие (prior) параметры для всех групп 

31 self.omega_prior = torch.tensor(1.0, device=self.device, requires_grad=False) 

32 self.chi_prior = torch.tensor(1.0, device=self.device, requires_grad=False) 

33 self.phi_prior = torch.tensor(1.0, device=self.device, requires_grad=False) 

34 self.nu_prior = torch.tensor(1.0, device=self.device, requires_grad=False) # гиперпараметр для v_i 

35 # Постериорные параметры для каждой группы 

36 self.omega_post = [torch.tensor(1.0, device=self.device, requires_grad=False) for _ in range(G)] 

37 self.chi_post = [torch.tensor(1.0, device=self.device, requires_grad=False) for _ in range(G)] 

38 self.phi_post = [torch.tensor(1.0, device=self.device, requires_grad=False) for _ in range(G)] 

39 self.nu_post = [torch.tensor(1.0, device=self.device, requires_grad=False) for _ in range(G)] 

40 self.tau = torch.tensor(1.0, device=self.device, requires_grad=False) # дисперсия шума 

41 self.sigma2 = torch.tensor(1.0, device=self.device, requires_grad=False) # дисперсия шума 

42 self.K = 2 # ранг латентного пространства, можно параметризовать 

43 self.P = self.model(torch.zeros(1, self.model.in_features, device=self.device)).shape[-1] # количество выходов 

44 self.D = self.model.in_features # количество входов 

45 self.Omega_inv_g = [torch.eye(len(idxs), device=self.device) for idxs in self.group_indices] # [D_g] 

46 self.gammas = [1.0 for _ in range(G)] 

47 # Параметры постериорного распределения для W, Z, V 

48 # W: матрично-нормальное (M_W, Omega_W, S_W) 

49 self.M_W = torch.randn(self.P, self.D, device=self.device) / np.sqrt(self.P * self.D) 

50 self.Omega_W = torch.eye(self.D, device=self.device) 

51 self.S_W = torch.eye(self.P, device=self.device) 

52 # Z: список по группам (M_Z[g], Omega_Z[g], S_Z[g]) 

53 self.M_Z = [torch.randn(self.K, len(idxs), device=self.device) / np.sqrt(self.K * len(idxs)) for idxs in self.group_indices] 

54 self.Omega_Z = [torch.eye(len(idxs), device=self.device) for idxs in self.group_indices] 

55 self.S_Z = [torch.eye(self.K, device=self.device) for _ in self.group_indices] 

56 # V: матрично-нормальное (M_V, Omega_V, S_V) 

57 self.M_V = torch.zeros(self.P, self.K, device=self.device) 

58 self.Omega_V = torch.eye(self.K, device=self.device) 

59 self.S_V = torch.eye(self.P, device=self.device) 

60 

61 def _get_flat_params(self) -> Tensor: 

62 """ 

63 Вытягивает параметры модели в один вектор. 

64 Возвращает: 

65 Tensor: Вектор параметров модели. 

66 """ 

67 # Вытягивает параметры модели в один вектор 

68 return torch.cat([p.view(-1) for p in self.model.parameters()]) 

69 

70 def _set_flat_params(self, flat_params: Tensor) -> None: 

71 """ 

72 Устанавливает параметры модели из вектора. 

73 Аргументы: 

74 flat_params (Tensor): Вектор параметров модели. 

75 """ 

76 # Устанавливает параметры модели из вектора 

77 pointer = 0 

78 for p in self.model.parameters(): 

79 numel = p.numel() 

80 p.data.copy_(flat_params[pointer:pointer+numel].view_as(p)) 

81 pointer += numel 

82 @staticmethod 

83 def mean_gig(omega: float, chi: float, phi: float) -> float: 

84 """ 

85 Вычисляет математическое ожидание ⟨x⟩ для GIG(omega, chi, phi). 

86 ⟨x⟩ = sqrt(chi/phi) * R_omega(sqrt(chi*phi)) 

87 где R_omega(z) = K_{omega+1}(z) / K_{omega}(z) 

88 Аргументы: 

89 omega (float): Параметр omega распределения GIG. 

90 chi (float): Параметр chi распределения GIG. 

91 phi (float): Параметр phi распределения GIG. 

92 Возвращает: 

93 float: Математическое ожидание ⟨x⟩. 

94 """ 

95 z = (chi * phi) ** 0.5 

96 K_omega = kv(omega, z) 

97 K_omega_p1 = kv(omega + 1, z) 

98 R_omega = K_omega_p1 / K_omega if K_omega != 0 else 0.0 

99 return (chi / phi) ** 0.5 * R_omega 

100 

101 @staticmethod 

102 def mean_inv_gig(omega: float, chi: float, phi: float) -> float: 

103 """ 

104 Вычисляет математическое ожидание обратной величины ⟨1/x⟩ для GIG(omega, chi, phi). 

105 ⟨1/x⟩ = sqrt(chi/phi) * R_{omega-1}(sqrt(chi*phi)) 

106 где R_{omega-1}(z) = K_{omega}(z) / K_{omega-1}(z) 

107 Аргументы: 

108 omega (float): Параметр omega распределения GIG. 

109 chi (float): Параметр chi распределения GIG. 

110 phi (float): Параметр phi распределения GIG. 

111 Возвращает: 

112 float: Математическое ожидание ⟨1/x⟩. 

113 """ 

114 z = (chi * phi) ** 0.5 

115 K_omega = kv(omega, z) 

116 K_omega_m1 = kv(omega - 1, z) 

117 R_omega = K_omega_m1 / K_omega if K_omega != 0 else 0.0 

118 return (phi / chi) ** 0.5 * R_omega 

119 

120 @staticmethod 

121 def mean_log_gig(omega: float, chi: float, phi: float) -> float: 

122 """ 

123 Вычисляет математическое ожидание логарифма ⟨log(x)⟩ для GIG(omega, chi, phi). 

124 Аргументы: 

125 omega (float): Параметр omega распределения GIG. 

126 chi (float): Параметр chi распределения GIG. 

127 phi (float): Параметр phi распределения GIG. 

128 Возвращает: 

129 float: Математическое ожидание ⟨log(x)⟩. 

130 """ 

131 z = (chi * phi) ** 0.5 

132 return 0.5 * np.log(chi / phi) + (SparseBayesianRegression.d_log_bessel_k(omega, z)) 

133 

134 @staticmethod 

135 def d_log_bessel_k(omega: float, z: float) -> float: 

136 """ 

137 Вычисляет производную по omega от log K_omega(z). 

138 Аргументы: 

139 omega (float): Параметр omega. 

140 z (float): Параметр z. 

141 Возвращает: 

142 float: Значение производной. 

143 """ 

144 # Производная по omega от log K_omega(z) 

145 eps = 1e-5 

146 return (np.log(kv(omega + eps, z)) - np.log(kv(omega - eps, z))) / (2 * eps) 

147 

148 def update_gig_hyperparams(self, group_idx: int, mean_gamma: float, mean_inv_gamma: float, mean_log_gamma: float) -> None: 

149 """ 

150 Обновляет гиперпараметры GIG (omega, chi, phi) для одной группы. 

151 Аргументы: 

152 group_idx (int): Индекс группы. 

153 mean_gamma (float): Математическое ожидание ⟨gamma⟩. 

154 mean_inv_gamma (float): Математическое ожидание ⟨1/gamma⟩. 

155 mean_log_gamma (float): Математическое ожидание ⟨log(gamma)⟩. 

156 """ 

157 # Численно решает систему для omega, chi, phi для одной группы 

158 Q = 1 # для одной группы, если групп больше, можно обобщить 

159 def equations(params): 

160 omega, chi, phi = params 

161 z = np.sqrt(chi * phi) 

162 K_omega = kv(omega, z) 

163 d_logK = self.d_log_bessel_k(omega, z) 

164 R_omega = kv(omega + 1, z) / K_omega if K_omega != 0 else 0.0 

165 eq1 = Q * np.log(np.sqrt(phi / chi)) - Q * d_logK - Q * mean_log_gamma 

166 eq2 = (Q * omega) / chi - (Q / 2) * np.sqrt(phi / chi) * R_omega + 0.5 * mean_inv_gamma 

167 eq3 = (Q / np.sqrt(chi * phi)) * R_omega - mean_gamma 

168 return [eq1, eq2, eq3] 

169 # Начальные значения 

170 omega0 = float(self.omega_post[group_idx].cpu().numpy()) 

171 chi0 = float(self.chi_post[group_idx].cpu().numpy()) 

172 phi0 = float(self.phi_post[group_idx].cpu().numpy()) 

173 sol = root(equations, [omega0, chi0, phi0], method='hybr') 

174 if sol.success: 

175 self.omega_post[group_idx] = torch.tensor(sol.x[0], device=self.device) 

176 self.chi_post[group_idx] = torch.tensor(sol.x[1], device=self.device) 

177 self.phi_post[group_idx] = torch.tensor(sol.x[2], device=self.device) 

178 

179 def update_gig_prior(self, mean_gammas: List[float], mean_inv_gammas: List[float], mean_log_gammas: List[float]) -> None: 

180 """ 

181 Обновляет гиперпараметры GIG (omega, chi, phi) для общего прайора. 

182 Аргументы: 

183 mean_gammas (List[float]): Список математических ожиданий ⟨gamma⟩ для всех групп. 

184 mean_inv_gammas (List[float]): Список математических ожиданий ⟨1/gamma⟩ для всех групп. 

185 mean_log_gammas (List[float]): Список математических ожиданий ⟨log(gamma)⟩ для всех групп. 

186 """ 

187 # Численно решает систему для omega, chi, phi для общего прайора (по средним по группам) 

188 Q = len(mean_gammas) 

189 sum_gamma = torch.sum(torch.tensor(mean_gammas)) 

190 sum_inv_gamma = torch.sum(torch.tensor(mean_inv_gammas)) 

191 sum_log_gamma = torch.sum(torch.tensor(mean_log_gammas)) 

192 def equations(params): 

193 omega, chi, phi = params 

194 z = np.sqrt(chi * phi) 

195 K_omega = kv(omega, z) 

196 d_logK = self.d_log_bessel_k(omega, z) 

197 R_omega = kv(omega + 1, z) / K_omega if K_omega != 0 else 0.0 

198 eq1 = Q * np.log(np.sqrt(phi / chi)) - Q * d_logK * sum_log_gamma 

199 eq2 = (Q * omega) / chi - (Q / 2) * np.sqrt(phi / chi) * R_omega + 0.5 * sum_inv_gamma 

200 eq3 = Q * np.sqrt(chi/ phi) * R_omega - sum_gamma 

201 return [eq1, eq2, eq3] 

202 omega0 = float(self.omega_prior.cpu().numpy()) 

203 chi0 = float(self.chi_prior.cpu().numpy()) 

204 phi0 = float(self.phi_prior.cpu().numpy()) 

205 sol = root(equations, [omega0, chi0, phi0], method='hybr') 

206 if sol.success: 

207 self.omega_prior = torch.tensor(sol.x[0], device=self.device) 

208 self.chi_prior = torch.tensor(sol.x[1], device=self.device) 

209 self.phi_prior = torch.tensor(sol.x[2], device=self.device) 

210 

211 def compute_moments_W(self, M_W: Tensor, Omega_W: Tensor, S_W: Tensor) -> Tensor: 

212 """ 

213 Вычисляет момент ⟨W W^T⟩ для матрицы W. 

214 Аргументы: 

215 M_W (Tensor): Матрица средних значений W. 

216 Omega_W (Tensor): Ковариационная матрица по строкам W. 

217 S_W (Tensor): Ковариационная матрица по столбцам W. 

218 Возвращает: 

219 Tensor: Момент ⟨W W^T⟩. 

220 """ 

221 # Момент: E[W W^T] = M_W M_W^T + tr(S_W) * Omega_W 

222 return M_W @ M_W.t() + torch.trace(S_W) * Omega_W 

223 

224 def compute_moments_Z(self, M_Z: Tensor, Omega_Z: Tensor, S_Z: Tensor) -> Tensor: 

225 """ 

226 Вычисляет момент ⟨Z Z^T⟩ для матрицы Z. 

227 Аргументы: 

228 M_Z (Tensor): Матрица средних значений Z. 

229 Omega_Z (Tensor): Ковариационная матрица по строкам Z. 

230 S_Z (Tensor): Ковариационная матрица по столбцам Z. 

231 Возвращает: 

232 Tensor: Момент ⟨Z Z^T⟩. 

233 """ 

234 # Момент: E[Z Z^T] = M_Z M_Z^T + tr(S_Z) * Omega_Z 

235 return M_Z @ M_Z.t() + torch.trace(S_Z) * Omega_Z 

236 

237 def compute_moments_VVT(self, M_V: Tensor, Omega_V: Tensor, S_V: Tensor) -> Tensor: 

238 """ 

239 Вычисляет момент ⟨V V^T⟩ для матрицы V. 

240 Аргументы: 

241 M_V (Tensor): Матрица средних значений V. 

242 Omega_V (Tensor): Ковариационная матрица по строкам V. 

243 S_V (Tensor): Ковариационная матрица по столбцам V. 

244 Возвращает: 

245 Tensor: Момент ⟨V V^T⟩. 

246 """ 

247 # Момент: E[V V^T] = M_V M_V^T + tr(S_V) * Omega_V 

248 return M_V @ M_V.t() + torch.trace(Omega_V) * S_V 

249 def compute_moments_VTV(self, M_V: Tensor, Omega_V: Tensor, S_V: Tensor) -> Tensor: 

250 """ 

251 Вычисляет момент ⟨V^T V⟩ для матрицы V. 

252 Аргументы: 

253 M_V (Tensor): Матрица средних значений V. 

254 Omega_V (Tensor): Ковариационная матрица по строкам V. 

255 S_V (Tensor): Ковариационная матрица по столбцам V. 

256 Возвращает: 

257 Tensor: Момент ⟨V^T V⟩. 

258 """ 

259 # Момент: E[V V^T] = M_V M_V^T + tr(S_V) * Omega_V 

260 return M_V.T @ M_V + torch.trace(S_V) * Omega_V 

261 

262 def e_step(self, X: Tensor, Y: Tensor) -> Dict[str, Any]: 

263 """ 

264 E-шаг: координирует обновление всех параметров постериорных распределений и моментов. 

265 Аргументы: 

266 X (Tensor): матрица признаков (D, N) 

267 Y (Tensor): матрица откликов (P, N) 

268 Возвращает: dict[str, Any] — словарь с основными статистиками и параметрами для M-шагa. 

269 """ 

270 group_moments = self._compute_group_moments() 

271 self._update_posterior_matrices(X, Y, group_moments) 

272 self._update_posterior_wishart(group_moments) 

273 self._update_posterior_gig(group_moments) 

274 return { 

275 'mean_gammas': group_moments['mean_gammas'], 

276 'mean_inv_gammas': group_moments['mean_inv_gammas'], 

277 'mean_log_gammas': group_moments['mean_log_gammas'], 

278 "M_W": self.M_W, "Omega_W": self.Omega_W, "S_W": self.S_W, 

279 "M_Z": self.M_Z, "Omega_Z": self.Omega_Z, "S_Z": self.S_Z, 

280 "M_V": self.M_V, "Omega_V": self.Omega_V, "S_V": self.S_V 

281 } 

282 

283 def _compute_group_moments(self) -> Dict[str, Any]: 

284 """ 

285 Вычисляет моменты (средние значения) по всем группам, а также блочные матрицы Gamma и Omega_inv. 

286 Возвращает: dict[str, Any] — словарь с этими величинами. 

287 """ 

288 G = len(self.group_indices) 

289 D = self.D 

290 mean_gammas, mean_inv_gammas, mean_log_gammas = [], [], [] 

291 Gamma = torch.zeros(D, D, device=self.device) 

292 Omega_inv = torch.zeros(D, D, device=self.device) 

293 for g, idxs in enumerate(self.group_indices): 

294 omega = float(self.omega_post[g].cpu().numpy()) 

295 chi = float(self.chi_post[g].cpu().numpy()) 

296 phi = float(self.phi_post[g].cpu().numpy()) 

297 mg = self.mean_gig(omega, chi, phi) 

298 mig = self.mean_inv_gig(omega, chi, phi) 

299 mlg = self.mean_log_gig(omega, chi, phi) 

300 mean_gammas.append(mg) 

301 mean_inv_gammas.append(mig) 

302 mean_log_gammas.append(mlg) 

303 Gamma[idxs, :][:, idxs] = mg * torch.eye(len(idxs), device=self.device) 

304 Omega_inv[idxs, :][:, idxs] = self.Omega_inv_g[g] 

305 return { 

306 'mean_gammas': mean_gammas, 

307 'mean_inv_gammas': mean_inv_gammas, 

308 'mean_log_gammas': mean_log_gammas, 

309 'Gamma': Gamma, 

310 'Omega_inv': Omega_inv 

311 } 

312 

313 def _update_posterior_matrices(self, X: Tensor, Y: Tensor, group_moments: Dict[str, Any]) -> None: 

314 """ 

315 Обновляет параметры постериорных матрично-нормальных распределений W, Z, V на основе текущих моментов и данных. 

316 Аргументы: 

317 X (Tensor): матрица признаков 

318 Y (Tensor): матрица откликов 

319 group_moments (dict): словарь с моментами и блочными матрицами 

320 Возвращает: None 

321 """ 

322 Gamma = group_moments['Gamma'] 

323 Omega_inv = group_moments['Omega_inv'] 

324 tau = self.tau 

325 sigma2 = self.sigma2 

326 # Обновление W 

327 Omega_W_inv = (1.0 / tau) * Omega_inv @ Gamma + (1.0 / sigma2) * (X @ X.t()) 

328 self.Omega_W = torch.linalg.inv(Omega_W_inv) 

329 Z = torch.cat(self.M_Z, dim=1) # [K, D] 

330 self.M_W = ((1.0 / tau) * self.M_V @ Z @ Omega_inv @ Gamma + (1.0 / sigma2) * Y @ X.t()) @ self.Omega_W 

331 self.S_W = torch.eye(self.P, device=self.device) 

332 # Обновление Z 

333 self.M_Z, self.Omega_Z, self.S_Z = [], [], [] 

334 for g, idxs in enumerate(self.group_indices): 

335 Dg = len(idxs) 

336 Wg = self.M_W[:, idxs] # [P, Dg] 

337 moment_V = self.compute_moments_VTV(self.M_V, self.Omega_V, self.S_V) 

338 S_Zi = torch.linalg.inv((1.0 / tau) * moment_V + torch.eye(self.K, device=self.device)) 

339 M_Zi = (1.0 / tau) * S_Zi @ self.M_V.t() @ Wg 

340 Omega_Zi = (1.0 / group_moments['mean_gammas'][g]) * torch.linalg.inv(self.Omega_inv_g[g]) 

341 self.M_Z.append(M_Zi) 

342 self.Omega_Z.append(Omega_Zi) 

343 self.S_Z.append(S_Zi) 

344 # Обновление V 

345 Omega_V_inv = torch.zeros(self.K, self.K, device=self.device) 

346 for g, idxs in enumerate(self.group_indices): 

347 Omega_i_inv = self.Omega_inv_g[g] 

348 M_Zi = self.M_Z[g] 

349 S_Zi = self.S_Z[g] 

350 moment_Z = Omega_i_inv * torch.trace(S_Zi) + M_Zi @ Omega_i_inv @ M_Zi.t() 

351 Omega_V_inv += group_moments['mean_gammas'][g] * moment_Z 

352 Omega_V_inv += torch.eye(self.K, device=self.device) 

353 self.Omega_V = torch.linalg.inv(Omega_V_inv) 

354 self.M_V = self.M_W @ Omega_inv @ Gamma @ Z.t() @ self.Omega_V 

355 self.S_V = tau * torch.eye(self.P, device=self.device) 

356 

357 def _update_posterior_wishart(self, group_moments: Dict[str, Any]) -> None: 

358 """ 

359 Обновляет параметры постериорного распределения Wishart (Lambda, nu, Omega_inv_g) для каждой группы. 

360 Аргументы: 

361 group_moments (dict): словарь с моментами и блочными матрицами 

362 Возвращает: None 

363 """ 

364 self.Lambda, self.nu_post = [], [] 

365 for g, idxs in enumerate(self.group_indices): 

366 Dg = len(idxs) 

367 Wg = self.M_W[:, idxs] # [P, Dg] 

368 Zg = self.M_Z[g] # [K, Dg] 

369 resid = Wg - self.M_V @ Zg.T # [P, Dg] 

370 Omega_W_g = self.Omega_W[idxs, :][:, idxs] 

371 moment_resid_mean_part = resid.T @ resid 

372 Omega_i_inv = self.Omega_inv_g[g] 

373 D_W = torch.trace(self.S_W) * Omega_W_g 

374 moment_resid_disp_part = D_W + self.Omega_Z[g] * torch.sum(self.S_Z[g] * self.Omega_V) * torch.trace(self.S_V) + self.M_Z[g].T @ self.Omega_V @self.M_Z[g] * torch.trace(self.S_V)+\ 

375 self.Omega_Z[g] * torch.sum(self.S_Z[g] * (self.M_Z[g].T @ self.M_Z[g])) 

376 moment_resid = moment_resid_mean_part + moment_resid_disp_part 

377 moment_Z = self.Omega_Z[g] * torch.sum(self.S_Z[g] @ Omega_i_inv) + self.M_Z[g] @ Omega_i_inv @ self.M_Z[g].T 

378 Lambda_i = (1.0 / self.tau) * group_moments['mean_gammas'][g] * moment_resid + group_moments['mean_gammas'][g] * moment_Z + torch.eye(Dg, device=self.device) 

379 self.Lambda.append(Lambda_i) 

380 self.nu_post.append(float(self.nu_prior) + self.P + self.K) 

381 self.Omega_inv_g = [(Dg + self.nu_post[g] - 1) * torch.linalg.inv(self.Lambda[g]) for g, idxs in enumerate(self.group_indices)] 

382 

383 def _update_posterior_gig(self, group_moments: Dict[str, Any]) -> None: 

384 """ 

385 Обновляет параметры постериорного GIG (omega_post, chi_post, phi_post) для каждой группы. 

386 Аргументы: 

387 group_moments (dict): словарь с моментами и блочными матрицами 

388 Возвращает: None 

389 """ 

390 for g, idxs in enumerate(self.group_indices): 

391 Dg = len(idxs) 

392 # omega_post 

393 self.omega_post[g] = self.omega_prior + 0.5 * (self.P + self.K) * Dg 

394 # chi_post 

395 self.chi_post[g] = self.chi_prior 

396 # phi_post 

397 #Вспомогательные переменные 

398 Omega_i_inv = self.Omega_inv_g[g] 

399 Wg = self.M_W[:, idxs] # [P, Dg] 

400 Zg = self.M_Z[g] # [K, Dg] 

401 resid = Wg - self.M_V @ Zg 

402 Omega_W_g = self.Omega_W[idxs, :][:, idxs] 

403 Omega_inv_i = self.Omega_inv_g[g] 

404 #Подсчеты моментов 

405 moment_resid_mean_part = resid @ Omega_i_inv @ resid.t() 

406 moment_resid_W_disp_part = self.S_W * torch.sum(Omega_W_g * Omega_inv_i) 

407 moment_resid_VZ_disp_part = self.S_V * torch.sum(self.Omega_V * self.S_Z[g]) * torch.sum(self.Omega_Z[g]* Omega_i_inv) + self.M_V @ self.S_Z[g] @self.M_V.T * torch.sum(self.Omega_Z[g]* Omega_i_inv)+\ 

408 self.S_V * torch.sum(self.Omega_V * (self.M_Z[g] @ Omega_inv_i @self.M_Z[g].T)) 

409 moment_resid = moment_resid_mean_part + moment_resid_W_disp_part + moment_resid_VZ_disp_part 

410 tr_resid = torch.trace(moment_resid) 

411 # Момент: <Z_i Omega_i^{-1} Z_i^T> 

412 moment_Z = Zg @ Omega_i_inv @ Zg.t() + torch.trace(self.S_Z[g]) * Omega_i_inv 

413 tr_Z = torch.trace(moment_Z) 

414 self.phi_post[g] = self.phi_prior + (1.0 / self.tau) * tr_resid + tr_Z 

415 def m_step(self, post: Dict[str, Any]) -> None: 

416 """ 

417 M-шаг: обновляет гиперпараметры модели на основе результатов E-шагa. 

418  

419 Аргументы: 

420 post (Dict[str, Any]): Словарь с результатами E-шагa, включая моменты и параметры постериорного распределения. 

421 """ 

422 # Теперь только обновление общего прайора 

423 self.update_gig_prior(post['mean_gammas'], post['mean_inv_gammas'], post['mean_log_gammas']) 

424 # Обновление tau (дисперсия шума) 

425 # ... (см. 4.2, зависит от задачи) 

426 def fit(self, X: Tensor, Y: Tensor, num_iter: int = 10) -> None: 

427 """ 

428 Обучает модель с использованием EM-алгоритма. 

429  

430 Аргументы: 

431 X (Tensor): Матрица признаков (N, D). 

432 Y (Tensor): Матрица откликов (N, P). 

433 num_iter (int): Количество итераций EM-алгоритма. 

434 """ 

435 X = X.T 

436 for _ in range(num_iter): 

437 post = self.e_step(X, Y) 

438 self.m_step(post) 

439 self._set_flat_params(post["M_W"]) 

440 

441 def predict(self, X: Tensor) -> Tensor: 

442 """ 

443 Выполняет предсказание на основе обученной модели. 

444  

445 Аргументы: 

446 X (Tensor): Матрица признаков (N, D). 

447  

448 Возвращает: 

449 Tensor: Предсказания модели (N, P). 

450 """ 

451 self.model.eval() 

452 with torch.no_grad(): 

453 return self.model(X).T