ssl_tools.benchmarks.main_mix_style
Module Contents
- class ssl_tools.benchmarks.main_mix_style.CNN_HaEtAl_1D(input_shape=(1, 6, 60), num_classes=6, learning_rate=0.001)
Bases:
SimpleClassificationNet2
- Parameters:
input_shape (Tuple[int, int, int])
num_classes (int)
learning_rate (float)
- _calculate_fc_input_features(backbone, input_shape)
Run a single forward pass with a random input to get the number of features after the convolutional layers.
Parameters
- backbone : torch.nn.Module
The backbone of the network.
- input_shape : Tuple[int, int, int]
The input shape of the network.
Returns
- int
The number of features after the convolutional layers.
- Parameters:
backbone (torch.nn.Module)
input_shape (Tuple[int, int, int])
- Return type:
int
- _create_backbone(input_shape)
- Parameters:
input_shape (Tuple[int, int])
- Return type:
torch.nn.Module
- _create_fc(input_features, num_classes)
- Parameters:
input_features (int)
num_classes (int)
- Return type:
torch.nn.Module
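The `_calculate_fc_input_features` docstring describes a probe forward pass: push one random sample through the backbone and count the features that come out. A minimal sketch of that idea in plain PyTorch follows, assuming the backbone accepts a batch shaped `(N, *input_shape)`. `calculate_fc_input_features` and the toy backbone are hypothetical illustrations, not this module's code.

```python
from typing import Tuple

import torch
from torch import nn


def calculate_fc_input_features(backbone: nn.Module, input_shape: Tuple[int, ...]) -> int:
    """Hypothetical stand-in for `_calculate_fc_input_features`: push one random
    sample through the backbone and count the flattened output features."""
    was_training = backbone.training
    backbone.eval()  # keep BatchNorm/Dropout from reacting to the random probe
    with torch.no_grad():
        out = backbone(torch.randn(1, *input_shape))  # batch of one sample
    backbone.train(was_training)  # restore the original mode
    return out.flatten(start_dim=1).size(1)


# Toy backbone (hypothetical, not CNN_HaEtAl_1D_Backbone) for inputs shaped (1, 6, 60):
# convolve and pool along the time axis only.
backbone = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=(1, 5)),  # (1, 6, 60) -> (8, 6, 56)
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=(1, 2)),     # (8, 6, 56) -> (8, 6, 28)
)
n_features = calculate_fc_input_features(backbone, (1, 6, 60))  # 8 * 6 * 28 = 1344
fc = nn.Linear(n_features, 6)  # 6 classes, matching the default num_classes
```

Switching to `eval()` for the probe is a design choice of this sketch, so that random data does not touch BatchNorm running statistics; whether the module does the same is not documented here.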
- class ssl_tools.benchmarks.main_mix_style.CNN_HaEtAl_1D_Backbone(input_channels=1)
Bases:
torch.nn.Module
- Parameters:
input_channels (int)
- forward(x)
- class ssl_tools.benchmarks.main_mix_style.CNN_HaEtAl_2D(pad_at=(3,), input_shape=(1, 6, 60), num_classes=6, learning_rate=0.001)
Bases:
SimpleClassificationNet2
- Parameters:
pad_at (List[int])
input_shape (Tuple[int, int, int])
num_classes (int)
learning_rate (float)
- _calculate_fc_input_features(backbone, input_shape)
Run a single forward pass with a random input to get the number of features after the convolutional layers.
Parameters
- backbone : torch.nn.Module
The backbone of the network.
- input_shape : Tuple[int, int, int]
The input shape of the network.
Returns
- int
The number of features after the convolutional layers.
- Parameters:
backbone (torch.nn.Module)
input_shape (Tuple[int, int, int])
- Return type:
int
- _create_backbone(input_shape)
- Parameters:
input_shape (Tuple[int, int])
- Return type:
torch.nn.Module
- _create_fc(input_features, num_classes)
- Parameters:
input_features (int)
num_classes (int)
- Return type:
torch.nn.Module
- class ssl_tools.benchmarks.main_mix_style.CNN_HaEtAl_2D_Backbone(pad_at, in_channels=1)
Bases:
torch.nn.Module
- Parameters:
pad_at (int)
in_channels (int)
- forward(x)
- class ssl_tools.benchmarks.main_mix_style.ConvolutionalBlock(in_channels, activation_cls=None)
Bases:
torch.nn.Module
- Parameters:
in_channels (int)
activation_cls (torch.nn.Module)
- forward(x)
- class ssl_tools.benchmarks.main_mix_style.ExperimentArgs
- data_cls: Any
- mix: bool = True
- model_args: Dict[str, Any]
- model_cls: Any
- seed: int = 42
- test_data_args: Dict[str, Any]
- train_data_args: Dict[str, Any]
- trainer_args: Dict[str, Any]
- trainer_cls: Any
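`ExperimentArgs` bundles everything a single run needs: model, data and trainer classes with their argument dictionaries, a seed, and a `mix` flag. The sketch below is a dataclass with the same fields, assuming `ExperimentArgs` is itself a dataclass. `ExperimentArgsSketch`, the `Dummy*` placeholder classes and the argument dictionaries are hypothetical; that `mix` switches MixStyle on or off is an assumption based on the module name.

```python
from dataclasses import dataclass, replace
from typing import Any, Dict


@dataclass
class ExperimentArgsSketch:
    """Hypothetical mirror of ExperimentArgs (same field names and defaults)."""
    model_cls: Any
    model_args: Dict[str, Any]
    data_cls: Any
    train_data_args: Dict[str, Any]
    test_data_args: Dict[str, Any]
    trainer_cls: Any
    trainer_args: Dict[str, Any]
    seed: int = 42
    mix: bool = True


class DummyModel:  # placeholders standing in for real model/data/trainer classes
    pass


class DummyData:
    pass


class DummyTrainer:
    pass


base = ExperimentArgsSketch(
    model_cls=DummyModel,
    model_args={"num_classes": 6, "learning_rate": 1e-3},
    data_cls=DummyData,
    train_data_args={"batch_size": 64},
    test_data_args={"batch_size": 64},
    trainer_cls=DummyTrainer,
    trainer_args={"max_epochs": 10},
)
# Same configuration with the mix flag switched off, e.g. for an ablation pair.
no_mix = replace(base, mix=False)
```

`dataclasses.replace` makes a shallow copy, so `no_mix` shares the same argument dictionaries as `base`; copy them explicitly before mutating one variant.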
- class ssl_tools.benchmarks.main_mix_style.ResNet1DBase(resnet_block_cls=ResNetBlock, activation_cls=torch.nn.ReLU, input_shape=(6, 60), num_classes=6, num_residual_blocks=5, reduction_ratio=2, learning_rate=0.001)
Bases:
SimpleClassificationNet2
- Parameters:
resnet_block_cls (type)
activation_cls (type)
input_shape (Tuple[int, int])
num_classes (int)
num_residual_blocks (int)
learning_rate (float)
- _calculate_fc_input_features(backbone, input_shape)
Run a single forward pass with a random input to get the number of features after the convolutional layers.
Parameters
- backbone : torch.nn.Module
The backbone of the network.
- input_shape : Tuple[int, int, int]
The input shape of the network.
Returns
- int
The number of features after the convolutional layers.
- Parameters:
backbone (torch.nn.Module)
input_shape (Tuple[int, int, int])
- Return type:
int
- class ssl_tools.benchmarks.main_mix_style.ResNet1D_8(*args, **kwargs)
Bases:
ResNet1DBase
- class ssl_tools.benchmarks.main_mix_style.ResNetBlock(in_channels=64, activation_cls=torch.nn.ReLU, mix_style_factor=False)
Bases:
torch.nn.Module
- Parameters:
in_channels (int)
activation_cls (torch.nn.Module)
- forward(x)
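The module name and the `mix_style_factor` parameter of `ResNetBlock` suggest MixStyle (Zhou et al., 2021), which, during training, re-normalizes each sample's features with channel statistics mixed from another sample in the batch to simulate new domains. The sketch below shows a 1D residual block with an optional MixStyle layer. The class names, the `use_mix_style` flag and the placement of MixStyle after the residual sum are assumptions, not this module's implementation.

```python
import torch
from torch import nn


class MixStyle1D(nn.Module):
    """MixStyle for (N, C, L) features: while training, with probability p,
    re-normalize each sample with channel statistics mixed from another sample."""

    def __init__(self, p: float = 0.5, alpha: float = 0.1, eps: float = 1e-6):
        super().__init__()
        self.p = p
        self.eps = eps
        self.beta = torch.distributions.Beta(alpha, alpha)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or torch.rand(1).item() > self.p:
            return x  # identity at inference time (and when the coin flip says so)
        n = x.size(0)
        mu = x.mean(dim=2, keepdim=True).detach()                      # (N, C, 1)
        sig = (x.var(dim=2, keepdim=True) + self.eps).sqrt().detach()  # (N, C, 1)
        x_norm = (x - mu) / sig
        lam = self.beta.sample((n, 1, 1)).to(x.device, x.dtype)
        perm = torch.randperm(n, device=x.device)  # partner sample for each row
        mu_mix = lam * mu + (1 - lam) * mu[perm]
        sig_mix = lam * sig + (1 - lam) * sig[perm]
        return x_norm * sig_mix + mu_mix


class ResidualBlock1D(nn.Module):
    """Hypothetical 1D residual block: conv-BN-act-conv-BN plus an identity
    shortcut, optionally followed by MixStyle."""

    def __init__(self, in_channels: int = 64, activation_cls=nn.ReLU,
                 use_mix_style: bool = False):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm1d(in_channels),
            activation_cls(),
            nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm1d(in_channels),
        )
        self.act = activation_cls()
        self.mix = MixStyle1D() if use_mix_style else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.act(self.body(x) + x)  # kernel 3 / padding 1 keeps length, so shapes match
        return self.mix(out)


block = ResidualBlock1D(in_channels=64, use_mix_style=True).train()
x = torch.randn(8, 64, 60)
y = block(x)
```

Because the statistics are detached and MixStyle returns its input unchanged in `eval()` mode, it only perturbs training and adds no parameters.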
- class ssl_tools.benchmarks.main_mix_style.ResNetSE1D_5(*args, **kwargs)
Bases:
ResNet1DBase
- class ssl_tools.benchmarks.main_mix_style.ResNetSE1D_8(*args, **kwargs)
Bases:
ResNet1DBase
- class ssl_tools.benchmarks.main_mix_style.ResNetSEBlock(*args, **kwargs)
Bases:
ResNetBlock
- class ssl_tools.benchmarks.main_mix_style.SimpleClassificationNet2(backbone, fc, learning_rate=0.001, flatten=True, loss_fn=None, train_metrics=None, val_metrics=None, test_metrics=None)
Bases:
ssl_tools.models.nets.simple.SimpleClassificationNet
- Parameters:
backbone (torch.nn.Module)
fc (torch.nn.Module)
learning_rate (float)
flatten (bool)
loss_fn (torch.nn.Module)
train_metrics (Dict[str, torch.Tensor])
val_metrics (Dict[str, torch.Tensor])
test_metrics (Dict[str, torch.Tensor])
- single_step(batch, batch_idx, step_name)
- Parameters:
batch (torch.Tensor)
batch_idx (int)
step_name (str)
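`single_step` takes a `step_name`, which suggests one shared routine for the training, validation and test steps. A sketch of such a step in plain PyTorch follows, assuming `batch` is an `(inputs, labels)` pair and that metrics are keyed with a `step_name` prefix; `single_step_sketch` and the toy modules are hypothetical. It also shows what the `flatten` flag of `SimpleClassificationNet2` presumably controls.

```python
from typing import Dict, Tuple

import torch
import torch.nn.functional as F
from torch import nn


def single_step_sketch(backbone: nn.Module, fc: nn.Module,
                       batch: Tuple[torch.Tensor, torch.Tensor],
                       step_name: str, flatten: bool = True
                       ) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Hypothetical shared train/val/test step: forward, loss, prefixed metrics."""
    x, y = batch  # assumption: batch is an (inputs, labels) pair
    features = backbone(x)
    if flatten:
        features = features.flatten(start_dim=1)  # (N, C, ...) -> (N, C * ...)
    logits = fc(features)
    loss = F.cross_entropy(logits, y)
    acc = (logits.argmax(dim=1) == y).float().mean().item()
    return loss, {f"{step_name}_loss": loss.item(), f"{step_name}_acc": acc}


backbone = nn.Sequential(nn.Conv1d(6, 16, kernel_size=5), nn.ReLU())  # 60 -> 56 steps
fc = nn.Linear(16 * 56, 6)
batch = (torch.randn(4, 6, 60), torch.randint(0, 6, (4,)))
loss, logs = single_step_sketch(backbone, fc, batch, step_name="val")
```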
- class ssl_tools.benchmarks.main_mix_style.SqueezeAndExcitation1D(in_channels, reduction_ratio=2)
Bases:
torch.nn.Module
- Parameters:
in_channels (int)
reduction_ratio (int)
- forward(input_tensor)
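`SqueezeAndExcitation1D(in_channels, reduction_ratio)` matches the parameters of a Squeeze-and-Excitation block (Hu et al.). A minimal 1D sketch of that technique: average over time (squeeze), pass the channel vector through a bottleneck with `in_channels // reduction_ratio` hidden units and a sigmoid (excitation), then rescale each channel. `SE1DSketch` is hypothetical and its layer details are assumptions.

```python
import torch
from torch import nn


class SE1DSketch(nn.Module):
    """Squeeze-and-Excitation for (N, C, L) inputs; hypothetical sketch."""

    def __init__(self, in_channels: int, reduction_ratio: int = 2):
        super().__init__()
        hidden = max(1, in_channels // reduction_ratio)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, in_channels),
            nn.Sigmoid(),  # per-channel gates in (0, 1)
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        squeezed = input_tensor.mean(dim=2)        # squeeze: (N, C)
        weights = self.fc(squeezed).unsqueeze(-1)  # excitation: (N, C, 1)
        return input_tensor * weights              # rescale each channel over all timesteps


se = SE1DSketch(in_channels=64, reduction_ratio=2)
x = torch.randn(4, 64, 60)
y = se(x)
```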
- class ssl_tools.benchmarks.main_mix_style._ResNet1D(input_shape, residual_block_cls=ResNetBlock, activation_cls=torch.nn.ReLU, num_residual_blocks=5, reduction_ratio=2)
Bases:
torch.nn.Module
- Parameters:
input_shape (Tuple[int, int])
activation_cls (torch.nn.Module)
num_residual_blocks (int)
- forward(x)
- ssl_tools.benchmarks.main_mix_style._run_experiment_wrapper(experiment_args)
- Parameters:
experiment_args (ExperimentArgs)
- ssl_tools.benchmarks.main_mix_style.cli_main(experiment)
- Parameters:
experiment (ExperimentArgs)
- ssl_tools.benchmarks.main_mix_style.conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1)
3x3 convolution with padding
- Parameters:
in_planes (int)
out_planes (int)
stride (int)
groups (int)
dilation (int)
- Return type:
torch.nn.Conv2d
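`conv3x3` is documented only as a "3x3 convolution with padding". Its signature matches the helper of the same name in torchvision's ResNet, so the sketch below assumes the same definition (padding equal to dilation, no bias); `conv3x3_sketch` is a hypothetical name.

```python
import torch
from torch import nn


def conv3x3_sketch(in_planes: int, out_planes: int, stride: int = 1,
                   groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding, assuming the torchvision ResNet convention:
    padding equals dilation and the bias is omitted (a BatchNorm usually follows)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


x = torch.randn(2, 16, 6, 60)
same = conv3x3_sketch(16, 32)(x)                 # stride 1: spatial size kept
halved = conv3x3_sketch(16, 32, stride=2)(x)     # stride 2: ceil(H/2), ceil(W/2)
dilated = conv3x3_sketch(16, 32, dilation=2)(x)  # padding grows with dilation
```

With padding `d` and dilation `d`, PyTorch's output-size formula for a 3x3 kernel reduces to `floor((H - 1) / stride) + 1`, which is why stride 1 preserves the spatial size for any dilation.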
- ssl_tools.benchmarks.main_mix_style.conv3x3_dynamic(in_planes, out_planes, stride=1, attention_in_channels=None)
3x3 convolution with padding
- Parameters:
in_planes (int)
out_planes (int)
stride (int)
attention_in_channels (int)
- Return type:
dassl.modeling.ops.Conv2dDynamic
- ssl_tools.benchmarks.main_mix_style.main_loo()
- ssl_tools.benchmarks.main_mix_style.pretty_print_experiment_args(args, indent=4)
- Parameters:
args (ExperimentArgs)
indent (int)
- Return type:
str
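`pretty_print_experiment_args(args, indent=4)` renders the arguments as a string. One possible rendering is sketched below, assuming `args` is a dataclass: one `name: value` line per field, with dictionaries expanded one key per line. `ArgsSketch` and `pretty_print_sketch` are hypothetical, and the real output format may differ.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict


@dataclass
class ArgsSketch:  # hypothetical stand-in with a subset of the ExperimentArgs fields
    model_args: Dict[str, Any]
    seed: int = 42
    mix: bool = True


def pretty_print_sketch(args: Any, indent: int = 4) -> str:
    """Render a dataclass as one 'name: value' line per field; dict fields are
    expanded one key per line, nested one more `indent` level."""
    pad = " " * indent
    lines = [f"{type(args).__name__}("]
    for f in fields(args):
        value = getattr(args, f.name)
        if isinstance(value, dict):
            lines.append(f"{pad}{f.name}:")
            lines.extend(f"{pad * 2}{k}: {v!r}" for k, v in value.items())
        else:
            lines.append(f"{pad}{f.name}: {value!r}")
    lines.append(")")
    return "\n".join(lines)


example = ArgsSketch(model_args={"num_classes": 6})
print(pretty_print_sketch(example, indent=2))
```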
- ssl_tools.benchmarks.main_mix_style.run_serial(experiments)
- Parameters:
experiments (List[ExperimentArgs])
- ssl_tools.benchmarks.main_mix_style.run_using_ray(experiments, ray_address=None)
- Parameters:
experiments (List[ExperimentArgs])
ray_address (str)
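`run_serial` and `run_using_ray` both take a list of `ExperimentArgs`. The sketch below shows the serial pattern and, as a plain stdlib stand-in for Ray (not what `run_using_ray` uses), a thread pool. The explicit `run_one` callable plays the role that `_run_experiment_wrapper` presumably plays; all names here are hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List


def run_serial_sketch(experiments: List[Any], run_one: Callable[[Any], Any]) -> List[Any]:
    """Run experiments one after another, in order (hypothetical helper)."""
    return [run_one(exp) for exp in experiments]


def run_parallel_sketch(experiments: List[Any], run_one: Callable[[Any], Any],
                        max_workers: int = 4) -> List[Any]:
    """Stdlib stand-in for a Ray-based runner: fan experiments out to a thread pool.
    Results come back in input order because Executor.map preserves ordering."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(run_one, experiments))


def fake_run(seed: int) -> int:  # hypothetical "experiment" that just derives a value
    return seed * 2


serial = run_serial_sketch([1, 2, 3], fake_run)
parallel = run_parallel_sketch([1, 2, 3], fake_run)
```

Threads only speed things up when the work releases the GIL; Ray instead runs tasks in separate worker processes and, presumably when `ray_address` is given, on an existing cluster.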