ssl_tools.analysis.latent_analysis
Module Contents
- class ssl_tools.analysis.latent_analysis.LatentAnalysis(layers, sklearn_cls, output_name_suffix='transformed', **sklearn_kwargs)
  - Parameters:
    - layers (List[str])
    - output_name_suffix (str)
  - __call__(trainer, model, data_module)
    - Parameters:
      - trainer (lightning.Trainer)
      - model (lightning.LightningModule)
      - data_module (lightning.LightningDataModule)
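
The constructor signature above suggests a "class plus keyword arguments" pattern: a transformer class is passed in together with the keyword arguments used to instantiate it. The following is a minimal sketch of that pattern only, not the library's implementation. `MeanCenter` and `transform_layer_outputs` are hypothetical names, `MeanCenter` is a pure-Python stand-in for a scikit-learn transformer, and the `"<layer>_<suffix>"` key format is an assumption about how `output_name_suffix` might be used.

```python
class MeanCenter:
    """Hypothetical stand-in for a scikit-learn transformer.

    It only mimics the ``fit_transform`` method: it subtracts the
    per-column mean and multiplies by ``scale``.
    """

    def __init__(self, scale=1.0):
        self.scale = scale

    def fit_transform(self, rows):
        n_rows = len(rows)
        n_cols = len(rows[0])
        means = [sum(row[c] for row in rows) / n_rows for c in range(n_cols)]
        return [
            [(row[c] - means[c]) * self.scale for c in range(n_cols)]
            for row in rows
        ]


def transform_layer_outputs(
    layer_outputs, transformer_cls, output_name_suffix="transformed", **transformer_kwargs
):
    """Apply a freshly built transformer to each layer's saved outputs.

    ``layer_outputs`` maps a layer name to a list of feature rows.
    The result key format is an assumption made for this sketch.
    """
    results = {}
    for layer_name, rows in layer_outputs.items():
        # Build one transformer per layer so fitted state is not shared.
        transformer = transformer_cls(**transformer_kwargs)
        results[f"{layer_name}_{output_name_suffix}"] = transformer.fit_transform(rows)
    return results


# Toy latent vectors for a single layer called "encoder".
latents = {"encoder": [[1.0, 2.0], [3.0, 6.0]]}
transformed = transform_layer_outputs(latents, MeanCenter, scale=2.0)
print(transformed)
```

With scikit-learn installed, a real transformer such as `sklearn.decomposition.PCA` with `n_components=2` could take the place of `MeanCenter`, since it also exposes `fit_transform`. Whether `LatentAnalysis` calls `fit_transform` internally is not stated in this reference.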
- class ssl_tools.analysis.latent_analysis.LayerOutputSaverHook
  - _forward_hook(module, inputs, outputs, layer_name)
    - Parameters:
      - layer_name (str)
  - attach_hooks(model, layer_names)
    - Parameters:
      - model (lightning.LightningModule)
      - layer_names (List[str])
  - outputs_from_layer(layer_name, concat=True)
    - Parameters:
      - layer_name (str)
      - concat (bool)
  - remove_hooks()
  - run_model_with_hooks(model, layer_names)
    - Parameters:
      - model (lightning.LightningModule)
      - layer_names (List[str])
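
The method names above follow the usual forward-hook pattern: register a callback on each layer of interest, record what the layer emits, then detach the callbacks. The sketch below illustrates that pattern in plain Python so it runs without any dependencies; it is not the library's code. `ToyLayer`, `HookHandle`, and `OutputSaver` are hypothetical. `ToyLayer.register_forward_hook` mirrors the shape of PyTorch's `torch.nn.Module.register_forward_hook`, which calls `hook(module, input, output)` and returns a handle with a `remove()` method. The extra `layer_name` argument is bound with `functools.partial`, which is one plausible way to get the four-argument `_forward_hook` signature listed above. Unlike the documented `attach_hooks(model, layer_names)`, the sketch takes a name-to-layer mapping directly to avoid modelling submodule lookup.

```python
from functools import partial


class HookHandle:
    """Stand-in for the removable handle PyTorch returns on registration."""

    def __init__(self, hooks, hook):
        self._hooks = hooks
        self._hook = hook

    def remove(self):
        if self._hook in self._hooks:
            self._hooks.remove(self._hook)


class ToyLayer:
    """Hypothetical stand-in for a layer: doubles every value in a batch."""

    def __init__(self):
        self._hooks = []

    def register_forward_hook(self, hook):
        self._hooks.append(hook)
        return HookHandle(self._hooks, hook)

    def __call__(self, batch):
        outputs = [2 * x for x in batch]
        # Call every registered hook after the forward computation.
        for hook in list(self._hooks):
            hook(self, (batch,), outputs)
        return outputs


class OutputSaver:
    """Collects per-layer outputs, one entry per forward call."""

    def __init__(self):
        self.saved = {}
        self.handles = []

    def _forward_hook(self, module, inputs, outputs, layer_name):
        self.saved.setdefault(layer_name, []).append(outputs)

    def attach_hooks(self, named_layers):
        for name, layer in named_layers.items():
            # Bind the layer name so the hook knows where to store outputs.
            hook = partial(self._forward_hook, layer_name=name)
            self.handles.append(layer.register_forward_hook(hook))

    def outputs_from_layer(self, layer_name, concat=True):
        batches = self.saved.get(layer_name, [])
        if concat:
            # Flatten all recorded batches into one list.
            return [x for batch in batches for x in batch]
        return list(batches)

    def remove_hooks(self):
        for handle in self.handles:
            handle.remove()
        self.handles.clear()


layer = ToyLayer()
saver = OutputSaver()
saver.attach_hooks({"encoder": layer})
layer([1, 2])
layer([3])
concatenated = saver.outputs_from_layer("encoder")
per_batch = saver.outputs_from_layer("encoder", concat=False)

saver.remove_hooks()
layer([10])  # no longer recorded
after_removal = saver.outputs_from_layer("encoder")
print(concatenated, per_batch, after_removal)
```

With real PyTorch tensors, the `concat=True` branch would more likely use something like `torch.cat` over the recorded batches; the list flattening here is only a dependency-free substitute.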