Here is a loss composed of several terms from a spatial transcriptomics task (spatial domain identification). Can losses computed by different loss functions be combined for multi-task learning with torchjd?
import torch
import torch.nn.functional as F
from torch.optim import Adam
+ from torchjd import mtl_backward
+ from torchjd.aggregation import UPGrad

optimizer = Adam(params, lr=0.1)
+ aggregator = UPGrad()

for ...:
    mse_loss = F.mse_loss(decoded, x)                              # MSE between decoder output and input
    bce_loss = F.binary_cross_entropy_with_logits(preds, labels)   # BCE on logits
    KLD = -0.5 / n_nodes * torch.mean(torch.sum(1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), 1))  # KL divergence of the latent distribution (VGAE-style)
    loss_contra = contrastive_loss(...)                            # contrastive loss term

    optimizer.zero_grad()
-   loss = mse_loss + bce_loss + KLD + loss_contra
-   loss.backward()
+   mtl_backward(losses=[mse_loss, bce_loss, KLD, loss_contra], features=features, aggregator=aggregator)
    optimizer.step()
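For reference, here is a minimal, self-contained sketch of mtl_backward used with losses from two different loss functions, closely following the MTL example in the torchjd documentation. The encoder/head modules, shapes, and synthetic data are placeholder assumptions for illustration, not taken from the code above:

import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU, Sequential
from torch.optim import Adam

from torchjd import mtl_backward
from torchjd.aggregation import UPGrad

# Placeholder architecture: one shared encoder and two task-specific heads.
shared = Sequential(Linear(10, 8), ReLU(), Linear(8, 4))
head_regression = Linear(4, 1)      # task 1: regression (MSE loss)
head_classification = Linear(4, 1)  # task 2: binary classification (BCE-with-logits loss)

params = [
    *shared.parameters(),
    *head_regression.parameters(),
    *head_classification.parameters(),
]
optimizer = Adam(params, lr=0.1)
aggregator = UPGrad()

# Synthetic data, only to make the sketch runnable.
x = torch.randn(16, 10)
reg_targets = torch.randn(16, 1)
cls_labels = torch.randint(0, 2, (16, 1)).float()

for step in range(10):
    features = shared(x)                    # shared representation used by every task
    reg_out = head_regression(features)
    cls_out = head_classification(features)

    mse_loss = F.mse_loss(reg_out, reg_targets)
    bce_loss = F.binary_cross_entropy_with_logits(cls_out, cls_labels)

    optimizer.zero_grad()
    # Aggregate the per-task gradients of the shared parameters
    # instead of backpropagating a single summed loss.
    mtl_backward(losses=[mse_loss, bce_loss], features=features, aggregator=aggregator)
    optimizer.step()

The key point of the pattern, as described in the torchjd docs, is that every loss is computed from the same shared features tensor: mtl_backward differentiates the losses with respect to features, aggregates the resulting Jacobian with UPGrad, and backpropagates the aggregated direction into the shared parameters, while each head gets the gradient of its own loss.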