
PR4: Add deep_mmd_loss files #170

Open · sanaAyrml wants to merge 20 commits into main

Conversation

sanaAyrml (Collaborator)

PR Type

[Feature]

Short Description

This is a tentative implementation of the deep MMD loss.

Tests Added

No tests added yet.

sanaAyrml requested a review from emersodb on June 7, 2024 at 06:38
EvaluationLosses: an instance of EvaluationLosses containing checkpoint loss and additional losses
indexed by name.
"""
for layer in self.flatten_feature_extraction_layers.keys():
@emersodb (Collaborator) · Jun 19, 2024

If you're going to be indexing into self.deep_mmd_losses anyway, could we simply do

for layer_loss_module in self.deep_mmd_losses.values():
    layer_loss_module.training = False

For Ditto, we do this process in validate and train_by_steps/train_by_epochs for the global model, maybe we can just do this there?

@emersodb (Collaborator)

I think it's still worth overriding compute_evaluation_loss and compute_training_loss and asserting that all layer_loss_module.training == False or vice versa though to be safe 🙂
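
A rough sketch of what such overrides could look like (the signatures below are illustrative and may not match the exact fl4health client API):

def compute_evaluation_loss(self, preds, features, target):
    # Sanity check: all deep MMD kernel modules should be frozen during evaluation.
    for layer_loss_module in self.deep_mmd_losses.values():
        assert not layer_loss_module.training
    return super().compute_evaluation_loss(preds, features, target)

def compute_training_loss(self, preds, features, target):
    # Conversely, the deep kernels should be in training mode while training.
    for layer_loss_module in self.deep_mmd_losses.values():
        assert layer_loss_module.training
    return super().compute_training_loss(preds, features, target)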

@emersodb (Collaborator)

I also might be missing this, but I don't see where we set layer_loss_module.training to True in the client. Based on the loss code, this would mean that we won't run training of the deep kernels after the first server round, which I think we want to keep doing?

@sanaAyrml (Collaborator, Author)

Good catch! The True setting was indeed missing, so I added it to the update_before_train function. Following your suggestion, I moved the False setting to the validate function. I kept the assertions in both compute_evaluation_loss and compute_training_loss functions for consistency.
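
In other words, roughly the following (a sketch; the actual hook signatures in the client may differ slightly):

def update_before_train(self, current_server_round):
    # Turn deep kernel optimization back on at the start of each round of local training.
    for layer_loss_module in self.deep_mmd_losses.values():
        layer_loss_module.training = True
    super().update_before_train(current_server_round)

def validate(self):
    # Freeze the deep kernels so that validation never updates them.
    for layer_loss_module in self.deep_mmd_losses.values():
        layer_loss_module.training = False
    return super().validate()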

@emersodb (Collaborator)

We can just iterate through the self.deep_mmd_losses values and do assertions I think?

for layer_loss_module in self.deep_mmd_losses.values():
    assert not layer_loss_module.training

list(self.featurizer.parameters()) + [self.epsilonOPT] + [self.sigmaOPT] + [self.sigma0OPT], lr=self.lr
)

def Pdist2(self, x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
@emersodb (Collaborator)

maybe expand this to pairwise_distance_squared?

@emersodb (Collaborator)

It looks like we don't leverage the fact that y can be None to get the distances of x with itself. Maybe we just drop that option and require y to be passed, to simplify this function.
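
For reference, a sketch of the simplified helper with y required, using the expanded name suggested above (illustrative only; it assumes torch is already imported in the module):

def pairwise_distance_squared(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Squared Euclidean distance between each row of x and each row of y:
    # ||x_i||^2 - 2 * x_i . y_j + ||y_j||^2, clamped at zero to guard against
    # small negative values caused by floating point error.
    x_norm = (x**2).sum(dim=1, keepdim=True)
    y_norm = (y**2).sum(dim=1, keepdim=True)
    distances = x_norm - 2.0 * torch.mm(x, y.t()) + y_norm.t()
    return torch.clamp(distances, min=0.0)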

# Compute output of deep network
model_output = self.featurizer(features)
# Compute epsilon, sigma and sigma_0
ep = torch.exp(self.epsilonOPT) / (1 + torch.exp(self.epsilonOPT))
@emersodb (Collaborator)

rename epsilon and note that it is the epsilon in $\kappa_w(x, y)$ in the paper

@emersodb (Collaborator)

It doesn't look like we did this? I think both the rename and comment are worthwhile

@sanaAyrml (Collaborator, Author)

Yeah I missed this
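
If it helps, the rename and comment might look roughly like this (epsilon_opt is an illustrative attribute name, not the current one):

# epsilon is the mixing weight epsilon in the deep kernel kappa_w(x, y) from the paper;
# the logistic transform maps the unconstrained parameter into (0, 1).
epsilon = torch.exp(self.epsilon_opt) / (1 + torch.exp(self.epsilon_opt))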

# Compute epsilon, sigma and sigma_0
ep = torch.exp(self.epsilonOPT) / (1 + torch.exp(self.epsilonOPT))
sigma = self.sigmaOPT**2
sigma0_u = self.sigma0OPT**2
@emersodb (Collaborator) · Jun 19, 2024

Based on the implementation of MMDu, I would suggest renaming sigma0 to sigma_phi, sigma0OPT to sigma_phi_opt, and sigma0_u to sigma_phi (since there doesn't seem to be any reason to have _u in there anyway). Similarly, anything that is sigma or sigmaOPT can be sigma_q or sigma_q_opt to match the notation of the paper.
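
With that rename, the block above would read roughly as follows (attribute names are illustrative):

# Bandwidths of the two Gaussian components of the deep kernel, following the paper's notation:
# sigma_q for the kernel on the raw inputs, sigma_phi for the kernel on the featurizer outputs.
sigma_q = self.sigma_q_opt**2
sigma_phi = self.sigma_phi_opt**2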

sanaAyrml changed the base branch from add_mkmmd_loss to sa_update_mkmmd_loss on September 16, 2024
sanaAyrml changed the base branch from sa_update_mkmmd_loss to sa_add_cifar10_experiments on September 17, 2024
sanaAyrml changed the title from "PR3: Add deep_mmd_loss files" to "PR4: Add deep_mmd_loss files" on September 17, 2024
Base automatically changed from sa_add_cifar10_experiments to main on October 7, 2024
@emersodb (Collaborator) left a comment

Really nice changes. Just added a few small comments and reminders of a few pieces you might have overlooked in my comments. Very close to ready to go!

for layer, layer_deep_mmd_loss in self.deep_mmd_losses.items():
    deep_mmd_loss = layer_deep_mmd_loss(features[layer], features[" ".join(["init_global", layer])])
    additional_losses["_".join(["deep_mmd_loss", layer])] = deep_mmd_loss
    total_deep_mmd_loss += deep_mmd_loss
total_loss += self.deep_mmd_loss_weight * total_deep_mmd_loss
additional_losses["deep_mmd_loss_total"] = total_deep_mmd_loss
@emersodb (Collaborator)

Just to be safe, maybe we can clone total_deep_mmd_loss here?

@sanaAyrml (Collaborator, Author)

I added that, but I am checking a bunch of other Ditto versions and we don't do this anywhere else. I am wondering whether I should update them as well or not.
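
For concreteness, the suggested safeguard amounts to something like this (a sketch, assuming total_deep_mmd_loss is a tensor that continues to be used afterwards):

# Clone the reported value so that later in-place updates to total_deep_mmd_loss
# cannot change the logged entry.
additional_losses["deep_mmd_loss_total"] = total_deep_mmd_loss.clone()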

ep = torch.exp(self.epsilonOPT) / (1 + torch.exp(self.epsilonOPT))
sigma = self.sigmaOPT**2
sigma0_u = self.sigma0OPT**2
# Compute J (STAT_u)
@emersodb (Collaborator)

I'd include the notation mention in your comment as well if you're alright with it (\hat{J}_{\lambda})
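
For context, in the deep-kernel MMD two-sample test literature this statistic is typically the regularized ratio of the unbiased MMD estimate to its estimated standard deviation; assuming STAT_u follows that convention, the quantity would be

$$\hat{J}_{\lambda}(S_P, S_Q) = \frac{\widehat{\mathrm{MMD}}_u^2(S_P, S_Q; k_\omega)}{\sqrt{\hat{\sigma}^2_{\mathcal{H}_1}(S_P, S_Q; k_\omega) + \lambda}}$$

where $\lambda$ is a small regularizer that keeps the denominator away from zero.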
