velora.utils
Documentation
Generic utility methods usable in any experiment.
set_device(device='auto')
Sets the PyTorch device dynamically.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
device | str | the name of the device to perform computations on. When | 'auto' |
Returns:
Name | Type | Description |
---|---|---|
device | torch.device | the |
Source code in velora/utils/core.py
set_seed(value=None)
Sets the random seed for Python, PyTorch and NumPy.
When value is None, a new seed is created automatically.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value | int | the seed value | None |
Returns:
Name | Type | Description |
---|---|---|
seed | int | the seed value that was used |
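
As a rough illustration of the behaviour described above, the sketch below seeds Python's random module and, when they are installed, NumPy and PyTorch, then returns the seed it used. It is not the code from velora/utils/core.py: the name seed_everything and the 32-bit range for auto-generated seeds are assumptions made for this example.

```python
import random
from typing import Optional


def seed_everything(value: Optional[int] = None) -> int:
    """Hypothetical stand-in for set_seed; not the library's implementation."""
    if value is None:
        # Assumption: draw a fresh seed from a 32-bit range.
        value = random.randint(0, 2**32 - 1)

    random.seed(value)

    # NumPy and PyTorch are seeded only if they can be imported.
    try:
        import numpy as np

        np.random.seed(value)
    except ImportError:
        pass

    try:
        import torch

        torch.manual_seed(value)
    except ImportError:
        pass

    return value


seed = seed_everything(42)
first = random.random()

seed_everything(seed)  # re-seeding with the returned value
second = random.random()  # repeats the draw stored in `first`
```

Keeping the returned seed makes it possible to log it alongside an experiment and replay the run later.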
Source code in velora/utils/core.py
active_parameters(model)
Calculates the number of active parameters in a PyTorch nn.Module.
Parameters whose value is 0 are excluded from the count.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | nn.Module | a PyTorch module with parameters | required |
Returns:
Name | Type | Description |
---|---|---|
count | int | the total number of active parameters |
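
To make the difference between total and active parameters concrete, here is a framework-free sketch that counts both over plain Python lists standing in for weight tensors. The function name count_parameters and the dict-of-lists layout are hypothetical; the library function takes an nn.Module instead.

```python
from typing import Dict, List, Tuple


def count_parameters(weights: Dict[str, List[float]]) -> Tuple[int, int]:
    """Return (total, active) counts, where active means non-zero."""
    total = sum(len(values) for values in weights.values())
    active = sum(1 for values in weights.values() for v in values if v != 0)
    return total, active


# Hypothetical weights: three of the six values are exactly zero.
weights = {
    "fc.weight": [0.5, 0.0, -1.2, 0.0],
    "fc.bias": [0.1, 0.0],
}
total, active = count_parameters(weights)
```

In this toy example total is 6 and active is 3, since the three zero entries are filtered out.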
Source code in velora/utils/torch.py
hard_update(source, target)
Performs a hard parameter update between two PyTorch networks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source | nn.Module | the source network | required |
target | nn.Module | the target network | required |
Source code in velora/utils/torch.py
soft_update(source, target, *, tau=0.005)
Performs a soft parameter update between two PyTorch networks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source | nn.Module | the source network | required |
target | nn.Module | the target network | required |
tau | float | the soft update factor used to slowly update the target network | 0.005 |
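
The role of tau can be illustrated with the interpolation commonly used for soft target updates, target = tau * source + (1 - tau) * target. This page does not spell out the exact formula used by soft_update, so treat that formula as an assumption; blend is a hypothetical helper that works on plain floats rather than network parameters.

```python
from typing import List


def blend(source: List[float], target: List[float], tau: float = 0.005) -> List[float]:
    """Move each target value a fraction `tau` of the way towards its source value."""
    return [tau * s + (1.0 - tau) * t for s, t in zip(source, target)]


source = [1.0, 2.0]
target = [0.0, 0.0]

half = blend(source, target, tau=0.5)  # halfway between target and source
full = blend(source, target, tau=1.0)  # target fully replaced by source
```

With a small tau such as the default 0.005, the target moves only slightly towards the source on each call; under this formula, tau=1.0 reduces to a full copy.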
Source code in velora/utils/torch.py
stack_tensor(items, *, dtype=torch.float32, device=None)
Stacks a list of tensors together, then:
- Converts it to a specific dtype
- Loads it onto device
Parameters:
Name | Type | Description | Default |
---|---|---|---|
items | List[torch.Tensor] | a list of torch.Tensors full of items | required |
dtype | torch.dtype | the data type for the tensor | torch.float32 |
device | torch.device | the device to perform computations on | None |
Returns:
Name | Type | Description |
---|---|---|
tensor | torch.Tensor | the updated |
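
A minimal sketch of the stack, cast, move sequence described above, built from torch.stack and Tensor.to. The helper name stack_and_cast is hypothetical, and the body is an assumption about how these steps could be combined, not the code in velora/utils/torch.py.

```python
from typing import List, Optional

import torch


def stack_and_cast(
    items: List[torch.Tensor],
    *,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Stack equally shaped tensors along a new first dimension, then cast and move."""
    return torch.stack(items).to(dtype=dtype, device=device)


# Two integer tensors of shape (2,) become one float32 tensor of shape (2, 2).
batch = stack_and_cast([torch.tensor([1, 2]), torch.tensor([3, 4])])
```

torch.stack requires every tensor in the list to have the same shape, so a helper like this suits batches of same-sized observations.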
Source code in velora/utils/torch.py
summary(module)
Outputs a summary of a module and all its sub-modules as a dictionary.
Returns:
Name | Type | Description |
---|---|---|
summary | Dict[str, str] | key-value pairs for the network layout |
Source code in velora/utils/torch.py
to_tensor(items, *, dtype=torch.float32, device=None)
Converts a list of items to a Tensor, then:
- Converts it to a specific dtype
- Loads it onto device
Parameters:
Name | Type | Description | Default |
---|---|---|---|
items | List[Any] | a list of items of any type | required |
dtype | torch.dtype | the data type for the tensor | torch.float32 |
device | torch.device | the device to perform computations on | None |
Returns:
Name | Type | Description |
---|---|---|
tensor | torch.Tensor | the updated |
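
The conversion described above can be approximated with torch.tensor followed by Tensor.to. This is a sketch under assumptions: list_to_tensor is a hypothetical name, and the library may perform the conversion differently.

```python
from typing import Any, List, Optional

import torch


def list_to_tensor(
    items: List[Any],
    *,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Build a tensor from a Python list, then cast it and move it to `device`."""
    return torch.tensor(items).to(dtype=dtype, device=device)


# Integer flags are cast to the default float32 dtype.
flags = list_to_tensor([1, 0, 1])
```

Unlike the stacking variant, the input here is a list of plain Python values rather than a list of existing tensors.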
Source code in velora/utils/torch.py
total_parameters(model)
Calculates the total number of parameters used in a PyTorch nn.Module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | nn.Module | a PyTorch module with parameters | required |
Returns:
Name | Type | Description |
---|---|---|
count | int | the total number of parameters |
Source code in velora/utils/torch.py