Coverage for bmm_multitask_learning/sbmtl/sparse_bayesian_regression.py: 90%
228 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-13 13:33 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-13 13:33 +0000
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4from torch import Tensor
5from typing import List, Optional, Dict, Any
6from scipy.optimize import root
7from scipy.special import kv, psi
8import numpy as np
11class SparseBayesianRegression:
12 def __init__(self, model: nn.Module, group_indices: List[List[int]],
13 device: Optional[str] = None):
14 """
15 model: torch.nn.Module (линейная или любая torch-модель)
16 group_indices: список списков индексов параметров, соответствующих группам
17 device: cpu/cuda
18 """
19 self.model = model
20 self.group_indices = group_indices
21 self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
22 self.model.to(self.device)
23 self._init_hyperparams()
25 def _init_hyperparams(self) -> None:
26 """
27 Инициализирует гиперпараметры модели, включая параметры прайора и постериора.
28 """
29 G = len(self.group_indices)
30 # Общие (prior) параметры для всех групп
31 self.omega_prior = torch.tensor(1.0, device=self.device, requires_grad=False)
32 self.chi_prior = torch.tensor(1.0, device=self.device, requires_grad=False)
33 self.phi_prior = torch.tensor(1.0, device=self.device, requires_grad=False)
34 self.nu_prior = torch.tensor(1.0, device=self.device, requires_grad=False) # гиперпараметр для v_i
35 # Постериорные параметры для каждой группы
36 self.omega_post = [torch.tensor(1.0, device=self.device, requires_grad=False) for _ in range(G)]
37 self.chi_post = [torch.tensor(1.0, device=self.device, requires_grad=False) for _ in range(G)]
38 self.phi_post = [torch.tensor(1.0, device=self.device, requires_grad=False) for _ in range(G)]
39 self.nu_post = [torch.tensor(1.0, device=self.device, requires_grad=False) for _ in range(G)]
40 self.tau = torch.tensor(1.0, device=self.device, requires_grad=False) # дисперсия шума
41 self.sigma2 = torch.tensor(1.0, device=self.device, requires_grad=False) # дисперсия шума
42 self.K = 2 # ранг латентного пространства, можно параметризовать
43 self.P = self.model(torch.zeros(1, self.model.in_features, device=self.device)).shape[-1] # количество выходов
44 self.D = self.model.in_features # количество входов
45 self.Omega_inv_g = [torch.eye(len(idxs), device=self.device) for idxs in self.group_indices] # [D_g]
46 self.gammas = [1.0 for _ in range(G)]
47 # Параметры постериорного распределения для W, Z, V
48 # W: матрично-нормальное (M_W, Omega_W, S_W)
49 self.M_W = torch.randn(self.P, self.D, device=self.device) / np.sqrt(self.P * self.D)
50 self.Omega_W = torch.eye(self.D, device=self.device)
51 self.S_W = torch.eye(self.P, device=self.device)
52 # Z: список по группам (M_Z[g], Omega_Z[g], S_Z[g])
53 self.M_Z = [torch.randn(self.K, len(idxs), device=self.device) / np.sqrt(self.K * len(idxs)) for idxs in self.group_indices]
54 self.Omega_Z = [torch.eye(len(idxs), device=self.device) for idxs in self.group_indices]
55 self.S_Z = [torch.eye(self.K, device=self.device) for _ in self.group_indices]
56 # V: матрично-нормальное (M_V, Omega_V, S_V)
57 self.M_V = torch.zeros(self.P, self.K, device=self.device)
58 self.Omega_V = torch.eye(self.K, device=self.device)
59 self.S_V = torch.eye(self.P, device=self.device)
61 def _get_flat_params(self) -> Tensor:
62 """
63 Вытягивает параметры модели в один вектор.
64 Возвращает:
65 Tensor: Вектор параметров модели.
66 """
67 # Вытягивает параметры модели в один вектор
68 return torch.cat([p.view(-1) for p in self.model.parameters()])
70 def _set_flat_params(self, flat_params: Tensor) -> None:
71 """
72 Устанавливает параметры модели из вектора.
73 Аргументы:
74 flat_params (Tensor): Вектор параметров модели.
75 """
76 # Устанавливает параметры модели из вектора
77 pointer = 0
78 for p in self.model.parameters():
79 numel = p.numel()
80 p.data.copy_(flat_params[pointer:pointer+numel].view_as(p))
81 pointer += numel
82 @staticmethod
83 def mean_gig(omega: float, chi: float, phi: float) -> float:
84 """
85 Вычисляет математическое ожидание ⟨x⟩ для GIG(omega, chi, phi).
86 ⟨x⟩ = sqrt(chi/phi) * R_omega(sqrt(chi*phi))
87 где R_omega(z) = K_{omega+1}(z) / K_{omega}(z)
88 Аргументы:
89 omega (float): Параметр omega распределения GIG.
90 chi (float): Параметр chi распределения GIG.
91 phi (float): Параметр phi распределения GIG.
92 Возвращает:
93 float: Математическое ожидание ⟨x⟩.
94 """
95 z = (chi * phi) ** 0.5
96 K_omega = kv(omega, z)
97 K_omega_p1 = kv(omega + 1, z)
98 R_omega = K_omega_p1 / K_omega if K_omega != 0 else 0.0
99 return (chi / phi) ** 0.5 * R_omega
101 @staticmethod
102 def mean_inv_gig(omega: float, chi: float, phi: float) -> float:
103 """
104 Вычисляет математическое ожидание обратной величины ⟨1/x⟩ для GIG(omega, chi, phi).
105 ⟨1/x⟩ = sqrt(chi/phi) * R_{omega-1}(sqrt(chi*phi))
106 где R_{omega-1}(z) = K_{omega}(z) / K_{omega-1}(z)
107 Аргументы:
108 omega (float): Параметр omega распределения GIG.
109 chi (float): Параметр chi распределения GIG.
110 phi (float): Параметр phi распределения GIG.
111 Возвращает:
112 float: Математическое ожидание ⟨1/x⟩.
113 """
114 z = (chi * phi) ** 0.5
115 K_omega = kv(omega, z)
116 K_omega_m1 = kv(omega - 1, z)
117 R_omega = K_omega_m1 / K_omega if K_omega != 0 else 0.0
118 return (phi / chi) ** 0.5 * R_omega
120 @staticmethod
121 def mean_log_gig(omega: float, chi: float, phi: float) -> float:
122 """
123 Вычисляет математическое ожидание логарифма ⟨log(x)⟩ для GIG(omega, chi, phi).
124 Аргументы:
125 omega (float): Параметр omega распределения GIG.
126 chi (float): Параметр chi распределения GIG.
127 phi (float): Параметр phi распределения GIG.
128 Возвращает:
129 float: Математическое ожидание ⟨log(x)⟩.
130 """
131 z = (chi * phi) ** 0.5
132 return 0.5 * np.log(chi / phi) + (SparseBayesianRegression.d_log_bessel_k(omega, z))
134 @staticmethod
135 def d_log_bessel_k(omega: float, z: float) -> float:
136 """
137 Вычисляет производную по omega от log K_omega(z).
138 Аргументы:
139 omega (float): Параметр omega.
140 z (float): Параметр z.
141 Возвращает:
142 float: Значение производной.
143 """
144 # Производная по omega от log K_omega(z)
145 eps = 1e-5
146 return (np.log(kv(omega + eps, z)) - np.log(kv(omega - eps, z))) / (2 * eps)
148 def update_gig_hyperparams(self, group_idx: int, mean_gamma: float, mean_inv_gamma: float, mean_log_gamma: float) -> None:
149 """
150 Обновляет гиперпараметры GIG (omega, chi, phi) для одной группы.
151 Аргументы:
152 group_idx (int): Индекс группы.
153 mean_gamma (float): Математическое ожидание ⟨gamma⟩.
154 mean_inv_gamma (float): Математическое ожидание ⟨1/gamma⟩.
155 mean_log_gamma (float): Математическое ожидание ⟨log(gamma)⟩.
156 """
157 # Численно решает систему для omega, chi, phi для одной группы
158 Q = 1 # для одной группы, если групп больше, можно обобщить
159 def equations(params):
160 omega, chi, phi = params
161 z = np.sqrt(chi * phi)
162 K_omega = kv(omega, z)
163 d_logK = self.d_log_bessel_k(omega, z)
164 R_omega = kv(omega + 1, z) / K_omega if K_omega != 0 else 0.0
165 eq1 = Q * np.log(np.sqrt(phi / chi)) - Q * d_logK - Q * mean_log_gamma
166 eq2 = (Q * omega) / chi - (Q / 2) * np.sqrt(phi / chi) * R_omega + 0.5 * mean_inv_gamma
167 eq3 = (Q / np.sqrt(chi * phi)) * R_omega - mean_gamma
168 return [eq1, eq2, eq3]
169 # Начальные значения
170 omega0 = float(self.omega_post[group_idx].cpu().numpy())
171 chi0 = float(self.chi_post[group_idx].cpu().numpy())
172 phi0 = float(self.phi_post[group_idx].cpu().numpy())
173 sol = root(equations, [omega0, chi0, phi0], method='hybr')
174 if sol.success:
175 self.omega_post[group_idx] = torch.tensor(sol.x[0], device=self.device)
176 self.chi_post[group_idx] = torch.tensor(sol.x[1], device=self.device)
177 self.phi_post[group_idx] = torch.tensor(sol.x[2], device=self.device)
179 def update_gig_prior(self, mean_gammas: List[float], mean_inv_gammas: List[float], mean_log_gammas: List[float]) -> None:
180 """
181 Обновляет гиперпараметры GIG (omega, chi, phi) для общего прайора.
182 Аргументы:
183 mean_gammas (List[float]): Список математических ожиданий ⟨gamma⟩ для всех групп.
184 mean_inv_gammas (List[float]): Список математических ожиданий ⟨1/gamma⟩ для всех групп.
185 mean_log_gammas (List[float]): Список математических ожиданий ⟨log(gamma)⟩ для всех групп.
186 """
187 # Численно решает систему для omega, chi, phi для общего прайора (по средним по группам)
188 Q = len(mean_gammas)
189 sum_gamma = torch.sum(torch.tensor(mean_gammas))
190 sum_inv_gamma = torch.sum(torch.tensor(mean_inv_gammas))
191 sum_log_gamma = torch.sum(torch.tensor(mean_log_gammas))
192 def equations(params):
193 omega, chi, phi = params
194 z = np.sqrt(chi * phi)
195 K_omega = kv(omega, z)
196 d_logK = self.d_log_bessel_k(omega, z)
197 R_omega = kv(omega + 1, z) / K_omega if K_omega != 0 else 0.0
198 eq1 = Q * np.log(np.sqrt(phi / chi)) - Q * d_logK * sum_log_gamma
199 eq2 = (Q * omega) / chi - (Q / 2) * np.sqrt(phi / chi) * R_omega + 0.5 * sum_inv_gamma
200 eq3 = Q * np.sqrt(chi/ phi) * R_omega - sum_gamma
201 return [eq1, eq2, eq3]
202 omega0 = float(self.omega_prior.cpu().numpy())
203 chi0 = float(self.chi_prior.cpu().numpy())
204 phi0 = float(self.phi_prior.cpu().numpy())
205 sol = root(equations, [omega0, chi0, phi0], method='hybr')
206 if sol.success:
207 self.omega_prior = torch.tensor(sol.x[0], device=self.device)
208 self.chi_prior = torch.tensor(sol.x[1], device=self.device)
209 self.phi_prior = torch.tensor(sol.x[2], device=self.device)
211 def compute_moments_W(self, M_W: Tensor, Omega_W: Tensor, S_W: Tensor) -> Tensor:
212 """
213 Вычисляет момент ⟨W W^T⟩ для матрицы W.
214 Аргументы:
215 M_W (Tensor): Матрица средних значений W.
216 Omega_W (Tensor): Ковариационная матрица по строкам W.
217 S_W (Tensor): Ковариационная матрица по столбцам W.
218 Возвращает:
219 Tensor: Момент ⟨W W^T⟩.
220 """
221 # Момент: E[W W^T] = M_W M_W^T + tr(S_W) * Omega_W
222 return M_W @ M_W.t() + torch.trace(S_W) * Omega_W
224 def compute_moments_Z(self, M_Z: Tensor, Omega_Z: Tensor, S_Z: Tensor) -> Tensor:
225 """
226 Вычисляет момент ⟨Z Z^T⟩ для матрицы Z.
227 Аргументы:
228 M_Z (Tensor): Матрица средних значений Z.
229 Omega_Z (Tensor): Ковариационная матрица по строкам Z.
230 S_Z (Tensor): Ковариационная матрица по столбцам Z.
231 Возвращает:
232 Tensor: Момент ⟨Z Z^T⟩.
233 """
234 # Момент: E[Z Z^T] = M_Z M_Z^T + tr(S_Z) * Omega_Z
235 return M_Z @ M_Z.t() + torch.trace(S_Z) * Omega_Z
237 def compute_moments_VVT(self, M_V: Tensor, Omega_V: Tensor, S_V: Tensor) -> Tensor:
238 """
239 Вычисляет момент ⟨V V^T⟩ для матрицы V.
240 Аргументы:
241 M_V (Tensor): Матрица средних значений V.
242 Omega_V (Tensor): Ковариационная матрица по строкам V.
243 S_V (Tensor): Ковариационная матрица по столбцам V.
244 Возвращает:
245 Tensor: Момент ⟨V V^T⟩.
246 """
247 # Момент: E[V V^T] = M_V M_V^T + tr(S_V) * Omega_V
248 return M_V @ M_V.t() + torch.trace(Omega_V) * S_V
249 def compute_moments_VTV(self, M_V: Tensor, Omega_V: Tensor, S_V: Tensor) -> Tensor:
250 """
251 Вычисляет момент ⟨V^T V⟩ для матрицы V.
252 Аргументы:
253 M_V (Tensor): Матрица средних значений V.
254 Omega_V (Tensor): Ковариационная матрица по строкам V.
255 S_V (Tensor): Ковариационная матрица по столбцам V.
256 Возвращает:
257 Tensor: Момент ⟨V^T V⟩.
258 """
259 # Момент: E[V V^T] = M_V M_V^T + tr(S_V) * Omega_V
260 return M_V.T @ M_V + torch.trace(S_V) * Omega_V
262 def e_step(self, X: Tensor, Y: Tensor) -> Dict[str, Any]:
263 """
264 E-шаг: координирует обновление всех параметров постериорных распределений и моментов.
265 Аргументы:
266 X (Tensor): матрица признаков (D, N)
267 Y (Tensor): матрица откликов (P, N)
268 Возвращает: dict[str, Any] — словарь с основными статистиками и параметрами для M-шагa.
269 """
270 group_moments = self._compute_group_moments()
271 self._update_posterior_matrices(X, Y, group_moments)
272 self._update_posterior_wishart(group_moments)
273 self._update_posterior_gig(group_moments)
274 return {
275 'mean_gammas': group_moments['mean_gammas'],
276 'mean_inv_gammas': group_moments['mean_inv_gammas'],
277 'mean_log_gammas': group_moments['mean_log_gammas'],
278 "M_W": self.M_W, "Omega_W": self.Omega_W, "S_W": self.S_W,
279 "M_Z": self.M_Z, "Omega_Z": self.Omega_Z, "S_Z": self.S_Z,
280 "M_V": self.M_V, "Omega_V": self.Omega_V, "S_V": self.S_V
281 }
283 def _compute_group_moments(self) -> Dict[str, Any]:
284 """
285 Вычисляет моменты (средние значения) по всем группам, а также блочные матрицы Gamma и Omega_inv.
286 Возвращает: dict[str, Any] — словарь с этими величинами.
287 """
288 G = len(self.group_indices)
289 D = self.D
290 mean_gammas, mean_inv_gammas, mean_log_gammas = [], [], []
291 Gamma = torch.zeros(D, D, device=self.device)
292 Omega_inv = torch.zeros(D, D, device=self.device)
293 for g, idxs in enumerate(self.group_indices):
294 omega = float(self.omega_post[g].cpu().numpy())
295 chi = float(self.chi_post[g].cpu().numpy())
296 phi = float(self.phi_post[g].cpu().numpy())
297 mg = self.mean_gig(omega, chi, phi)
298 mig = self.mean_inv_gig(omega, chi, phi)
299 mlg = self.mean_log_gig(omega, chi, phi)
300 mean_gammas.append(mg)
301 mean_inv_gammas.append(mig)
302 mean_log_gammas.append(mlg)
303 Gamma[idxs, :][:, idxs] = mg * torch.eye(len(idxs), device=self.device)
304 Omega_inv[idxs, :][:, idxs] = self.Omega_inv_g[g]
305 return {
306 'mean_gammas': mean_gammas,
307 'mean_inv_gammas': mean_inv_gammas,
308 'mean_log_gammas': mean_log_gammas,
309 'Gamma': Gamma,
310 'Omega_inv': Omega_inv
311 }
313 def _update_posterior_matrices(self, X: Tensor, Y: Tensor, group_moments: Dict[str, Any]) -> None:
314 """
315 Обновляет параметры постериорных матрично-нормальных распределений W, Z, V на основе текущих моментов и данных.
316 Аргументы:
317 X (Tensor): матрица признаков
318 Y (Tensor): матрица откликов
319 group_moments (dict): словарь с моментами и блочными матрицами
320 Возвращает: None
321 """
322 Gamma = group_moments['Gamma']
323 Omega_inv = group_moments['Omega_inv']
324 tau = self.tau
325 sigma2 = self.sigma2
326 # Обновление W
327 Omega_W_inv = (1.0 / tau) * Omega_inv @ Gamma + (1.0 / sigma2) * (X @ X.t())
328 self.Omega_W = torch.linalg.inv(Omega_W_inv)
329 Z = torch.cat(self.M_Z, dim=1) # [K, D]
330 self.M_W = ((1.0 / tau) * self.M_V @ Z @ Omega_inv @ Gamma + (1.0 / sigma2) * Y @ X.t()) @ self.Omega_W
331 self.S_W = torch.eye(self.P, device=self.device)
332 # Обновление Z
333 self.M_Z, self.Omega_Z, self.S_Z = [], [], []
334 for g, idxs in enumerate(self.group_indices):
335 Dg = len(idxs)
336 Wg = self.M_W[:, idxs] # [P, Dg]
337 moment_V = self.compute_moments_VTV(self.M_V, self.Omega_V, self.S_V)
338 S_Zi = torch.linalg.inv((1.0 / tau) * moment_V + torch.eye(self.K, device=self.device))
339 M_Zi = (1.0 / tau) * S_Zi @ self.M_V.t() @ Wg
340 Omega_Zi = (1.0 / group_moments['mean_gammas'][g]) * torch.linalg.inv(self.Omega_inv_g[g])
341 self.M_Z.append(M_Zi)
342 self.Omega_Z.append(Omega_Zi)
343 self.S_Z.append(S_Zi)
344 # Обновление V
345 Omega_V_inv = torch.zeros(self.K, self.K, device=self.device)
346 for g, idxs in enumerate(self.group_indices):
347 Omega_i_inv = self.Omega_inv_g[g]
348 M_Zi = self.M_Z[g]
349 S_Zi = self.S_Z[g]
350 moment_Z = Omega_i_inv * torch.trace(S_Zi) + M_Zi @ Omega_i_inv @ M_Zi.t()
351 Omega_V_inv += group_moments['mean_gammas'][g] * moment_Z
352 Omega_V_inv += torch.eye(self.K, device=self.device)
353 self.Omega_V = torch.linalg.inv(Omega_V_inv)
354 self.M_V = self.M_W @ Omega_inv @ Gamma @ Z.t() @ self.Omega_V
355 self.S_V = tau * torch.eye(self.P, device=self.device)
357 def _update_posterior_wishart(self, group_moments: Dict[str, Any]) -> None:
358 """
359 Обновляет параметры постериорного распределения Wishart (Lambda, nu, Omega_inv_g) для каждой группы.
360 Аргументы:
361 group_moments (dict): словарь с моментами и блочными матрицами
362 Возвращает: None
363 """
364 self.Lambda, self.nu_post = [], []
365 for g, idxs in enumerate(self.group_indices):
366 Dg = len(idxs)
367 Wg = self.M_W[:, idxs] # [P, Dg]
368 Zg = self.M_Z[g] # [K, Dg]
369 resid = Wg - self.M_V @ Zg.T # [P, Dg]
370 Omega_W_g = self.Omega_W[idxs, :][:, idxs]
371 moment_resid_mean_part = resid.T @ resid
372 Omega_i_inv = self.Omega_inv_g[g]
373 D_W = torch.trace(self.S_W) * Omega_W_g
374 moment_resid_disp_part = D_W + self.Omega_Z[g] * torch.sum(self.S_Z[g] * self.Omega_V) * torch.trace(self.S_V) + self.M_Z[g].T @ self.Omega_V @self.M_Z[g] * torch.trace(self.S_V)+\
375 self.Omega_Z[g] * torch.sum(self.S_Z[g] * (self.M_Z[g].T @ self.M_Z[g]))
376 moment_resid = moment_resid_mean_part + moment_resid_disp_part
377 moment_Z = self.Omega_Z[g] * torch.sum(self.S_Z[g] @ Omega_i_inv) + self.M_Z[g] @ Omega_i_inv @ self.M_Z[g].T
378 Lambda_i = (1.0 / self.tau) * group_moments['mean_gammas'][g] * moment_resid + group_moments['mean_gammas'][g] * moment_Z + torch.eye(Dg, device=self.device)
379 self.Lambda.append(Lambda_i)
380 self.nu_post.append(float(self.nu_prior) + self.P + self.K)
381 self.Omega_inv_g = [(Dg + self.nu_post[g] - 1) * torch.linalg.inv(self.Lambda[g]) for g, idxs in enumerate(self.group_indices)]
383 def _update_posterior_gig(self, group_moments: Dict[str, Any]) -> None:
384 """
385 Обновляет параметры постериорного GIG (omega_post, chi_post, phi_post) для каждой группы.
386 Аргументы:
387 group_moments (dict): словарь с моментами и блочными матрицами
388 Возвращает: None
389 """
390 for g, idxs in enumerate(self.group_indices):
391 Dg = len(idxs)
392 # omega_post
393 self.omega_post[g] = self.omega_prior + 0.5 * (self.P + self.K) * Dg
394 # chi_post
395 self.chi_post[g] = self.chi_prior
396 # phi_post
397 #Вспомогательные переменные
398 Omega_i_inv = self.Omega_inv_g[g]
399 Wg = self.M_W[:, idxs] # [P, Dg]
400 Zg = self.M_Z[g] # [K, Dg]
401 resid = Wg - self.M_V @ Zg
402 Omega_W_g = self.Omega_W[idxs, :][:, idxs]
403 Omega_inv_i = self.Omega_inv_g[g]
404 #Подсчеты моментов
405 moment_resid_mean_part = resid @ Omega_i_inv @ resid.t()
406 moment_resid_W_disp_part = self.S_W * torch.sum(Omega_W_g * Omega_inv_i)
407 moment_resid_VZ_disp_part = self.S_V * torch.sum(self.Omega_V * self.S_Z[g]) * torch.sum(self.Omega_Z[g]* Omega_i_inv) + self.M_V @ self.S_Z[g] @self.M_V.T * torch.sum(self.Omega_Z[g]* Omega_i_inv)+\
408 self.S_V * torch.sum(self.Omega_V * (self.M_Z[g] @ Omega_inv_i @self.M_Z[g].T))
409 moment_resid = moment_resid_mean_part + moment_resid_W_disp_part + moment_resid_VZ_disp_part
410 tr_resid = torch.trace(moment_resid)
411 # Момент: <Z_i Omega_i^{-1} Z_i^T>
412 moment_Z = Zg @ Omega_i_inv @ Zg.t() + torch.trace(self.S_Z[g]) * Omega_i_inv
413 tr_Z = torch.trace(moment_Z)
414 self.phi_post[g] = self.phi_prior + (1.0 / self.tau) * tr_resid + tr_Z
415 def m_step(self, post: Dict[str, Any]) -> None:
416 """
417 M-шаг: обновляет гиперпараметры модели на основе результатов E-шагa.
419 Аргументы:
420 post (Dict[str, Any]): Словарь с результатами E-шагa, включая моменты и параметры постериорного распределения.
421 """
422 # Теперь только обновление общего прайора
423 self.update_gig_prior(post['mean_gammas'], post['mean_inv_gammas'], post['mean_log_gammas'])
424 # Обновление tau (дисперсия шума)
425 # ... (см. 4.2, зависит от задачи)
426 def fit(self, X: Tensor, Y: Tensor, num_iter: int = 10) -> None:
427 """
428 Обучает модель с использованием EM-алгоритма.
430 Аргументы:
431 X (Tensor): Матрица признаков (N, D).
432 Y (Tensor): Матрица откликов (N, P).
433 num_iter (int): Количество итераций EM-алгоритма.
434 """
435 X = X.T
436 for _ in range(num_iter):
437 post = self.e_step(X, Y)
438 self.m_step(post)
439 self._set_flat_params(post["M_W"])
441 def predict(self, X: Tensor) -> Tensor:
442 """
443 Выполняет предсказание на основе обученной модели.
445 Аргументы:
446 X (Tensor): Матрица признаков (N, D).
448 Возвращает:
449 Tensor: Предсказания модели (N, P).
450 """
451 self.model.eval()
452 with torch.no_grad():
453 return self.model(X).T