velora.utils
Documentation
Generic utility methods usable in any experiment.
set_device(device='auto')
Sets the PyTorch device dynamically.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `device` | `str` | the name of the device to perform computations on. When `'auto'`, the device is selected automatically | `'auto'` |
Returns:
| Name | Type | Description |
|---|---|---|
| `device` | `torch.device` | the selected device |
Source code in velora/utils/core.py
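The description of `'auto'` is cut short above. The sketch below is a hypothetical re-implementation, not velora's source. It assumes `'auto'` picks CUDA when `torch.cuda.is_available()` reports a GPU and falls back to the CPU otherwise.

```python
import torch


def set_device(device: str = "auto") -> torch.device:
    """Hypothetical sketch of set_device; the 'auto' rule is an assumption."""
    if device == "auto":
        # Assumed rule: prefer CUDA when a GPU is visible, otherwise use the CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device)


# Usage: allocate a tensor on whichever device was selected.
dev = set_device()
x = torch.zeros(3, device=dev)
print(dev, x.device)
```

Passing an explicit name such as `"cpu"` or `"cuda:1"` skips the automatic selection and simply wraps the string in a `torch.device`.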
set_seed(value=None)
Sets the random seed for Python, PyTorch and NumPy.
When `value` is `None`, a new seed is generated automatically.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `value` | `int` | the seed value | `None` |
Returns:
| Name | Type | Description |
|---|---|---|
| `seed` | `int` | the seed value that was applied |
Source code in velora/utils/core.py
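A minimal sketch of seeding all three generators. It is not velora's implementation, and it assumes (the source does not say how) that a missing seed is drawn with Python's `random` module from the range NumPy's legacy `np.random.seed` accepts, 0 to 2**32 - 1.

```python
import random
from typing import Optional

import numpy as np
import torch


def set_seed(value: Optional[int] = None) -> int:
    """Hypothetical sketch of set_seed, not the velora implementation."""
    if value is None:
        # Assumed: draw a fresh seed in the range np.random.seed accepts.
        value = random.randint(0, 2**32 - 1)
    random.seed(value)        # Python's built-in RNG
    np.random.seed(value)     # NumPy's global (legacy) RNG
    torch.manual_seed(value)  # PyTorch RNG on all devices
    return value


# Usage: reseeding with the returned value reproduces the same draws.
seed = set_seed(42)
first = torch.rand(2)
set_seed(seed)
second = torch.rand(2)
print(torch.equal(first, second))  # True
```

Returning the seed matters most when `value` is `None`: logging the generated value is the only way to rerun the experiment identically later.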
active_parameters(model)
Calculates the number of active parameters in a PyTorch `nn.Module`.
Parameters with a value of 0 are filtered out.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `nn.Module` | a PyTorch module with parameters | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `count` | `int` | the number of active (non-zero) parameters |
Source code in velora/utils/torch.py
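One way to count active parameters, assuming "active" means individual parameter values that are not exactly zero (for example after pruning). This is an illustrative sketch, not velora's code.

```python
import torch
import torch.nn as nn


def active_parameters(model: nn.Module) -> int:
    """Hypothetical sketch: count parameter values that are non-zero."""
    # (p != 0) is a boolean mask; summing it counts the non-zero entries.
    return sum(int((p != 0).sum().item()) for p in model.parameters())


layer = nn.Linear(4, 2)  # 8 weights + 2 biases = 10 parameters in total
with torch.no_grad():
    layer.weight.fill_(1.0)  # keep all 8 weights active
    layer.bias.zero_()       # switch off the 2 biases
print(active_parameters(layer))  # 8
```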
hard_update(source, target)
Performs a hard parameter update between two PyTorch networks, overwriting the `target` parameters with the `source` parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `source` | `nn.Module` | the source network | required |
| `target` | `nn.Module` | the target network | required |
Source code in velora/utils/torch.py
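A hard update is commonly used to initialise a target network as an exact copy of its online network. A minimal sketch, not velora's source, assuming both modules share the same architecture so their parameters line up one-to-one:

```python
import torch
import torch.nn as nn


def hard_update(source: nn.Module, target: nn.Module) -> None:
    """Hypothetical sketch: overwrite target parameters with source parameters."""
    with torch.no_grad():  # the copy must not be tracked by autograd
        for target_param, source_param in zip(target.parameters(), source.parameters()):
            target_param.copy_(source_param)


source = nn.Linear(3, 1)
target = nn.Linear(3, 1)
hard_update(source, target)
print(all(torch.equal(s, t) for s, t in zip(source.parameters(), target.parameters())))  # True
```

`target.load_state_dict(source.state_dict())` is an alternative that also copies registered buffers (such as batch-norm running statistics); the loop above copies parameters only.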
soft_update(source, target, *, tau=0.005)
Performs a soft parameter update between two PyTorch networks, moving the `target` parameters a small step (controlled by `tau`) toward the `source` parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `source` | `nn.Module` | the source network | required |
| `target` | `nn.Module` | the target network | required |
| `tau` | `float` | the soft update factor used to slowly update the target network | `0.005` |
Source code in velora/utils/torch.py
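Soft (Polyak) updates are usually written as `target <- tau * source + (1 - tau) * target`. Assuming `tau` plays that role here, as its description suggests, a minimal sketch (not velora's source) is:

```python
import torch
import torch.nn as nn


def soft_update(source: nn.Module, target: nn.Module, *, tau: float = 0.005) -> None:
    """Hypothetical sketch: target <- tau * source + (1 - tau) * target."""
    with torch.no_grad():
        for target_param, source_param in zip(target.parameters(), source.parameters()):
            # Scale the old target values, then blend in a tau-weighted share of the source.
            target_param.mul_(1.0 - tau).add_(source_param, alpha=tau)


source = nn.Linear(2, 1)
target = nn.Linear(2, 1)
with torch.no_grad():
    for p in source.parameters():
        p.fill_(1.0)
    for p in target.parameters():
        p.fill_(0.0)

soft_update(source, target, tau=0.5)
print(target.weight.tolist())  # [[0.5, 0.5]]
```

With the default `tau=0.005`, each call moves the target only 0.5% of the way toward the source, which is what keeps the target network changing slowly.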
stack_tensor(items, *, dtype=torch.float32, device=None)
Stacks a list of tensors together, then:
- converts it to a specific `dtype`
- loads it onto `device`
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `items` | `List[torch.Tensor]` | the tensors to stack | required |
| `dtype` | `torch.dtype` | the data type for the tensor | `torch.float32` |
| `device` | `torch.device` | the device to perform computations on | `None` |
Returns:
| Name | Type | Description |
|---|---|---|
| `tensor` | `torch.Tensor` | the stacked tensor, converted to `dtype` and loaded onto `device` |
Source code in velora/utils/torch.py
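A minimal sketch, not velora's source, assuming the tensors are stacked along a new leading dimension (the default of `torch.stack`) and that `device=None` leaves the result on its current device:

```python
from typing import List, Optional

import torch


def stack_tensor(
    items: List[torch.Tensor],
    *,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Hypothetical sketch of stack_tensor, not the velora source."""
    tensor = torch.stack(items).to(dtype)  # new leading dimension, then cast
    if device is not None:
        tensor = tensor.to(device)  # assumed: None means "keep the current device"
    return tensor


items = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6])]
batch = stack_tensor(items)
print(batch.shape, batch.dtype)  # torch.Size([3, 2]) torch.float32
```

`torch.stack` requires every tensor in `items` to have the same shape; mismatched shapes raise an error.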
summary(module)
Outputs a summary of a module and all its sub-modules as a dictionary.
Returns:
| Name | Type | Description |
|---|---|---|
| `summary` | `Dict[str, str]` | key-value pairs for the network layout |
Source code in velora/utils/torch.py
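The exact key and value format is not documented here. The sketch below is an assumption, not velora's output format: one entry per module from `named_modules()`, keyed by its dotted name (the hypothetical label `"<root>"` stands in for the top-level module's empty name) and valued by the class name plus `extra_repr()`.

```python
from typing import Dict

import torch.nn as nn


def summary(module: nn.Module) -> Dict[str, str]:
    """Hypothetical sketch of summary; the output format is an assumption."""
    layout: Dict[str, str] = {}
    for name, sub in module.named_modules():
        key = name or "<root>"  # named_modules() yields "" for the module itself
        layout[key] = f"{type(sub).__name__}({sub.extra_repr()})"
    return layout


net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for key, value in summary(net).items():
    print(key, value)
```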
to_tensor(items, *, dtype=torch.float32, device=None)
Converts a list of items to a tensor, then:
- converts it to a specific `dtype`
- loads it onto `device`
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `items` | `List[Any]` | a list of items of any type | required |
| `dtype` | `torch.dtype` | the data type for the tensor | `torch.float32` |
| `device` | `torch.device` | the device to perform computations on | `None` |
Returns:
| Name | Type | Description |
|---|---|---|
| `tensor` | `torch.Tensor` | the resulting tensor, converted to `dtype` and loaded onto `device` |
Source code in velora/utils/torch.py
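A minimal sketch built on `torch.tensor`, which copies the data and applies `dtype` and `device` in one call. It is not velora's source. In this sketch "any type" in practice means values `torch.tensor` can convert, such as numbers, booleans, or nested lists of them.

```python
from typing import Any, List, Optional

import torch


def to_tensor(
    items: List[Any],
    *,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Hypothetical sketch of to_tensor, not the velora source."""
    # torch.tensor accepts dtype and device directly; device=None uses the current default device.
    return torch.tensor(items, dtype=dtype, device=device)


rewards = to_tensor([1, 0, 2])
print(rewards)  # tensor([1., 0., 2.])
```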
total_parameters(model)
Calculates the total number of parameters used in a PyTorch nn.Module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `nn.Module` | a PyTorch module with parameters | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `count` | `int` | the total number of parameters. |
Source code in velora/utils/torch.py
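A minimal sketch, not velora's source, assuming "total" counts every scalar value across `model.parameters()`, whether or not it is zero or trainable. This differs from `active_parameters`, which skips zeros.

```python
import torch.nn as nn


def total_parameters(model: nn.Module) -> int:
    """Hypothetical sketch of total_parameters, not the velora source."""
    # numel() gives the number of scalar values in each parameter tensor.
    return sum(p.numel() for p in model.parameters())


layer = nn.Linear(4, 2)  # 4 * 2 weights + 2 biases
print(total_parameters(layer))  # 10
```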