dacbench.envs.sgd
Module Contents
Classes
Reward: Enum where members are also (and must be) ints
SGDEnv: Environment to control the learning rate of Adam
Functions
reward_range(frange)
- dacbench.envs.sgd.reward_range(frange)
- class dacbench.envs.sgd.Reward
Bases:
enum.IntEnum
Enum where members are also (and must be) ints
- TrainingLoss
- ValidationLoss
- LogTrainingLoss
- LogValidationLoss
- DiffTraining
- DiffValidation
- LogDiffTraining
- LogDiffValidation
- FullTraining
- __call__(self, f)
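The signature `__call__(self, f)` on an `IntEnum` suggests that members can be applied to functions, for example as decorators that tie a reward type to its implementation. The following is a minimal sketch of that general pattern, not DACBench's actual code: the class `RewardKind`, the `_REGISTRY` dict, `compute_reward`, and the choice of negative loss as reward are all assumptions made for illustration.

```python
from enum import IntEnum, auto

# Hypothetical registry mapping each enum member to its reward function.
_REGISTRY = {}


class RewardKind(IntEnum):
    """Illustrative stand-in for an int-valued reward enum."""

    TrainingLoss = auto()
    ValidationLoss = auto()

    def __call__(self, f):
        # Calling a member with a function registers that function
        # for this reward kind and returns it unchanged.
        if self in _REGISTRY:
            raise ValueError(f"{self.name} is already registered")
        _REGISTRY[self] = f
        return f


@RewardKind.TrainingLoss
def training_reward(loss):
    # Assumption: reward is the negative loss.
    return -loss


@RewardKind.ValidationLoss
def validation_reward(loss):
    return -loss


def compute_reward(kind, loss):
    """Look up the function registered for `kind` and evaluate it."""
    return _REGISTRY[kind](loss)
```

With this pattern, `compute_reward(RewardKind.TrainingLoss, 2.0)` dispatches to `training_reward`, and because members are ints they can still be stored in configs or compared numerically.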
- class dacbench.envs.sgd.SGDEnv(config)
Bases:
dacbench.AbstractEnv
Environment to control the learning rate of Adam
- val_model
Samuel Mueller (PhD student in our group) also uses backpack and has run into a similar memory leak. He solved it by calling this custom recursive memory_cleanup function:
# from backpack import memory_cleanup
# def recursive_backpack_memory_cleanup(module: torch.nn.Module):
#     memory_cleanup(module)
#     for m in module.modules():
#         memory_cleanup(m)
Calling this after computing the training loss/gradients and after computing the validation loss should suffice.
- Type
TODO
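The recursive cleanup described in the note on `val_model` can be sketched without torch or backpack installed by passing the cleanup function in as a parameter. Everything named here is a hypothetical stand-in: `FakeModule` only mimics the `modules()` traversal of a module tree, and `recursive_cleanup` mirrors the structure of the function quoted in the note rather than any DACBench API.

```python
class FakeModule:
    """Hypothetical stand-in for a module tree with a modules() iterator."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)

    def modules(self):
        # Yield this module first, then every descendant, depth-first.
        yield self
        for child in self.children:
            yield from child.modules()


def recursive_cleanup(module, cleanup):
    """Apply `cleanup` to `module` and to everything in module.modules()."""
    cleanup(module)
    for m in module.modules():
        cleanup(m)


# Record which modules get cleaned, in order.
cleaned = []
root = FakeModule("root", [FakeModule("conv"), FakeModule("fc")])
recursive_cleanup(root, lambda m: cleaned.append(m.name))
```

In real use one would presumably pass backpack's `memory_cleanup` as `cleanup` and the network as `module`, once after the training loss and gradients are computed and once after the validation loss, as the note suggests.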
- get_reward(self)
- get_training_reward(self)
- get_validation_reward(self)
- get_log_training_reward(self)
- get_log_validation_reward(self)
- get_log_diff_training_reward(self)
- get_log_diff_validation_reward(self)
- get_diff_training_reward(self)
- get_diff_validation_reward(self)
- get_full_training_reward(self)
- get_full_training_loss(self)
- property crash(self)
- seed(self, seed=None, seed_action_space=False)
Set rng seed
- Parameters
seed – seed for rng
seed_action_space (bool, default False) – whether to seed the action space as well
- step(self, action)
Execute environment step
- Parameters
action (list) – action to execute
- Returns
state, reward, done, info
- Return type
np.array, float, bool, dict
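Since `step` returns the tuple `state, reward, done, info`, an episode can be driven with a simple loop. The sketch below illustrates only that calling convention. `ToyLREnv` is a hypothetical stand-in with made-up dynamics (gradient descent on a one-dimensional quadratic), it uses plain lists rather than `np.array`, and it is not `SGDEnv` itself, which needs a benchmark config to construct.

```python
class ToyLREnv:
    """Hypothetical environment with the same reset/step shape."""

    def __init__(self, horizon=3):
        self.horizon = horizon
        self.t = 0
        self.w = 1.0

    def reset(self):
        # Restart the episode and return the initial state.
        self.t = 0
        self.w = 1.0
        return [self.w * self.w]

    def step(self, action):
        # The action is a list whose first entry is the learning rate.
        lr = float(action[0])
        # One gradient step on L(w) = w^2 / 2, whose gradient is w.
        self.w -= lr * self.w
        self.t += 1
        loss = self.w * self.w
        reward = -loss  # assumption: reward is the negative loss
        done = self.t >= self.horizon
        info = {"step": self.t}
        return [loss], reward, done, info


def run_episode(env, lr):
    """Run one episode with a constant learning rate."""
    state = env.reset()
    total_reward = 0.0
    done = False
    info = {}
    while not done:
        state, reward, done, info = env.step([lr])
        total_reward += reward
    return total_reward, info
```

A learned policy would replace the constant `lr` with a value chosen from `state` at each step; the loop itself stays the same.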
- _architecture_constructor(self, arch_str)
- reset(self)
Reset environment
- Returns
Environment state
- Return type
np.array
- set_writer(self, writer)
- close(self)
No additional cleanup necessary
- Returns
Cleanup flag
- Return type
bool
- render(self, mode: str = 'human')
Render env in human mode
- Parameters
mode (str) – Execution mode
- get_default_state(self, _)
Gather state description
- Returns
Environment state
- Return type
dict
- _train_batch_(self)
- train_network(self)
- _get_full_training_loss(self, loader)
- property current_validation_loss(self)
- _get_validation_loss_(self)
- _get_validation_loss(self)
- _get_gradients(self)
- _get_momentum(self, gradients)
- get_adam_direction(self)
- get_rmsprop_direction(self)
- get_momentum_direction(self)
- _get_loss_features(self)
- _get_predictive_change_features(self, lr)
- _get_alignment(self)
- generate_instance_file(self, file_name, mode='test', n=100)