ssl_tools.analysis.latent_analysis

Classes

LatentAnalysis

LayerOutputSaverHook

Module Contents

class ssl_tools.analysis.latent_analysis.LatentAnalysis(layers, sklearn_cls, output_name_suffix='transformed', **sklearn_kwargs)
Parameters:
  • layers (List[str])

  • output_name_suffix (str)

__call__(trainer, model, data_module)
Parameters:
  • trainer (lightning.Trainer)

  • model (lightning.LightningModule)

  • data_module (lightning.LightningDataModule)
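
The constructor signature suggests that sklearn_cls is instantiated with sklearn_kwargs and applied to the outputs of the selected layers, although this reference does not say so explicitly. The following is a minimal sketch of that pattern under this assumption. The helper transform_latents and the random latents array are hypothetical and are not part of ssl_tools. PCA stands in for any scikit-learn class that provides fit_transform.

```python
import numpy as np
from sklearn.decomposition import PCA


def transform_latents(latents, sklearn_cls, **sklearn_kwargs):
    """Hypothetical helper: build the estimator from its class and kwargs,
    then fit it on the latents and return the transformed array."""
    estimator = sklearn_cls(**sklearn_kwargs)
    return estimator.fit_transform(latents)


# Stand-in for activations captured from one layer: 20 samples, 8 features.
rng = np.random.default_rng(0)
latents = rng.normal(size=(20, 8))

# Project the latents to 2 dimensions.
projected = transform_latents(latents, PCA, n_components=2)
print(projected.shape)
```

Passing the class and its keyword arguments separately, rather than a ready-made estimator, lets a fresh estimator be created each time the analysis runs.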

class ssl_tools.analysis.latent_analysis.LayerOutputSaverHook
_forward_hook(module, inputs, outputs, layer_name)
Parameters:
  • layer_name (str)

attach_hooks(model, layer_names)
Parameters:
  • model (lightning.LightningModule)

  • layer_names (List[str])

outputs_from_layer(layer_name, concat=True)
Parameters:
  • layer_name (str)

  • concat (bool)

remove_hooks()
run_model_with_hooks(model, layer_names)
Parameters:
  • model (lightning.LightningModule)

  • layer_names (List[str])
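
The method names of LayerOutputSaverHook match the usual PyTorch forward-hook pattern. The sketch below illustrates that pattern and is not the ssl_tools implementation: the class OutputSaverSketch, its attributes outputs and handles, and the toy nn.Sequential model are all assumptions made for illustration. It relies only on torch.nn.Module.register_forward_hook, which calls the hook as hook(module, inputs, outputs), so the extra layer_name argument is bound with functools.partial.

```python
from functools import partial
from typing import Dict, List

import torch
from torch import nn


class OutputSaverSketch:
    """Hypothetical stand-in that mirrors the documented method names."""

    def __init__(self) -> None:
        self.outputs: Dict[str, List[torch.Tensor]] = {}
        self.handles = []

    def _forward_hook(self, module, inputs, outputs, layer_name: str) -> None:
        # Store a detached CPU copy of the layer output, one entry per batch.
        self.outputs.setdefault(layer_name, []).append(outputs.detach().cpu())

    def attach_hooks(self, model: nn.Module, layer_names: List[str]) -> None:
        # Look up submodules by their dotted names and register one hook each.
        modules = dict(model.named_modules())
        for name in layer_names:
            hook = partial(self._forward_hook, layer_name=name)
            self.handles.append(modules[name].register_forward_hook(hook))

    def outputs_from_layer(self, layer_name: str, concat: bool = True):
        # Either one tensor stacked along the batch axis or the raw batch list.
        batches = self.outputs[layer_name]
        return torch.cat(batches, dim=0) if concat else batches

    def remove_hooks(self) -> None:
        for handle in self.handles:
            handle.remove()
        self.handles.clear()


# Toy model; in nn.Sequential the submodules are named "0", "1", "2".
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
saver = OutputSaverSketch()
saver.attach_hooks(model, ["0"])

with torch.no_grad():
    for _ in range(2):
        model(torch.randn(5, 4))  # two batches of 5 samples each

saver.remove_hooks()
with torch.no_grad():
    model(torch.randn(5, 4))  # hooks are removed, so this batch is not stored

latents = saver.outputs_from_layer("0")                # shape (10, 3)
batches = saver.outputs_from_layer("0", concat=False)  # list of 2 tensors
```

Removing the hooks after collection keeps later forward passes from accumulating activations in memory.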