dacbench.envs.sgd

Module Contents

Classes

Reward

Enum where members are also (and must be) ints

SGDEnv

Environment to control the learning rate of Adam

Functions

reward_range(frange)

dacbench.envs.sgd.reward_range(frange)
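Only the signature of reward_range is documented. The sketch below shows what a decorator with this signature could look like. It assumes that `frange` is a `(low, high)` tuple and that it is stored on the decorated function under a hypothetical `frange` attribute. It is an illustration, not dacbench's implementation:

```python
from typing import Callable, Tuple


def reward_range(frange: Tuple[float, float]) -> Callable:
    """Hypothetical sketch: tag a reward function with its (low, high) range."""
    low, high = frange
    if low > high:
        raise ValueError("frange must be (low, high) with low <= high")

    def decorator(f: Callable) -> Callable:
        # Attach the range to the function; the attribute name is an assumption.
        f.frange = (low, high)
        return f

    return decorator


@reward_range((-float("inf"), 0.0))
def negative_loss_reward(loss: float) -> float:
    # Example reward: a lower loss gives a higher (less negative) reward.
    return -loss
```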
class dacbench.envs.sgd.Reward

Bases: enum.IntEnum

Enum where members are also (and must be) ints

TrainingLoss
ValidationLoss
LogTrainingLoss
LogValidationLoss
DiffTraining
DiffValidation
LogDiffTraining
LogDiffValidation
FullTraining
__call__(self, f)
class dacbench.envs.sgd.SGDEnv(config)

Bases: dacbench.AbstractEnv

Environment to control the learning rate of Adam

val_model

Samuel Mueller (PhD student in our group) also uses BackPACK and has run into a similar memory leak. He solved it by calling this custom-made recursive memory_cleanup function (calling it after computing the training loss/gradients and after computing the validation loss should suffice):

    # from backpack import memory_cleanup
    # def recursive_backpack_memory_cleanup(module: torch.nn.Module):
    #     memory_cleanup(module)
    #     for m in module.modules():
    #         memory_cleanup(m)

Type

TODO

get_reward(self)
get_training_reward(self)
get_validation_reward(self)
get_log_training_reward(self)
get_log_validation_reward(self)
get_log_diff_training_reward(self)
get_log_diff_validation_reward(self)
get_diff_training_reward(self)
get_diff_validation_reward(self)
get_full_training_reward(self)
get_full_training_loss(self)
property crash(self)
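The reward getters above are named after the Reward members. One plausible reading of those names uses two kinds of reward. Loss-based rewards negate the loss, optionally after taking its log. Diff-based rewards measure how much the loss dropped since the previous step. The functions below are inferred from the method names only and are not taken from dacbench's source:

```python
import math


def loss_reward(loss: float, log: bool = False) -> float:
    """Negated (optionally log-) loss: lower loss means higher reward."""
    return -math.log(loss) if log else -loss


def diff_reward(prev_loss: float, loss: float, log: bool = False) -> float:
    """Improvement since the previous step: positive if the loss dropped."""
    if log:
        return math.log(prev_loss) - math.log(loss)
    return prev_loss - loss
```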
seed(self, seed=None, seed_action_space=False)

Set the RNG seed

Parameters
  • seed – seed for the RNG

  • seed_action_space (bool, default False) – whether to seed the action space as well
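The effect of `seed_action_space` can be shown with a small stand-in. In the sketch below, `ToyEnv` and `ToyActionSpace` are hypothetical classes built on the standard library's `random.Random`. They only mirror the documented `seed(seed=None, seed_action_space=False)` signature. The point is that the environment's RNG is always seeded, while the action space is seeded only when requested, which makes `action_space.sample()` reproducible:

```python
import random
from typing import Optional


class ToyActionSpace:
    """Hypothetical stand-in for a continuous space of learning rates."""

    def __init__(self, low: float, high: float):
        self.low, self.high = low, high
        self._rng = random.Random()

    def seed(self, seed: Optional[int] = None) -> None:
        self._rng.seed(seed)

    def sample(self) -> float:
        return self._rng.uniform(self.low, self.high)


class ToyEnv:
    def __init__(self):
        self.action_space = ToyActionSpace(0.0, 1.0)
        self.rng = random.Random()

    def seed(self, seed: Optional[int] = None, seed_action_space: bool = False) -> None:
        # Always seed the environment's own RNG.
        self.rng.seed(seed)
        # Optionally seed the action space so sampled actions are reproducible too.
        if seed_action_space:
            self.action_space.seed(seed)
```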

step(self, action)

Execute environment step

Parameters

action (list) – action to execute

Returns

state, reward, done, info

Return type

np.array, float, bool, dict
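step takes a list-valued action and returns the tuple (state, reward, done, info). Constructing a real SGDEnv needs a benchmark config and PyTorch. The sketch below therefore runs a typical rollout loop against a hypothetical `StandInEnv` that has the same reset/step/close signatures. Its loss dynamics are made up for illustration:

```python
from typing import Dict, List, Tuple


class StandInEnv:
    """Hypothetical stand-in with SGDEnv-like reset/step/close; toy dynamics."""

    def __init__(self, n_steps: int = 5):
        self.n_steps = n_steps
        self.t = 0
        self.loss = 1.0

    def reset(self) -> List[float]:
        self.t, self.loss = 0, 1.0
        return [self.loss]

    def step(self, action: List[float]) -> Tuple[List[float], float, bool, Dict]:
        lr = action[0]
        # Toy dynamics: the loss shrinks by a factor that depends on the learning rate.
        self.loss *= max(0.0, 1.0 - lr)
        self.t += 1
        done = self.t >= self.n_steps
        return [self.loss], -self.loss, done, {}

    def close(self) -> bool:
        return True


def run_episode(env, lr: float) -> float:
    """Roll out one episode with a constant learning rate and sum the rewards."""
    state = env.reset()
    total, done = 0.0, False
    while not done:
        state, reward, done, info = env.step([lr])
        total += reward
    return total
```

With `StandInEnv(n_steps=3)` and `lr=0.5`, the loss goes 0.5, 0.25, 0.125, so the summed reward is -0.875.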

_architecture_constructor(self, arch_str)
reset(self)

Reset environment

Returns

Environment state

Return type

np.array

set_writer(self, writer)
close(self)

No additional cleanup necessary

Returns

Cleanup flag

Return type

bool

render(self, mode: str = 'human')

Render env in human mode

Parameters

mode (str) – Execution mode

get_default_state(self, _)

Gather state description

Returns

Environment state

Return type

dict

_train_batch_(self)
train_network(self)
_get_full_training_loss(self, loader)
property current_validation_loss(self)
_get_validation_loss_(self)
_get_validation_loss(self)
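_get_full_training_loss(loader) and the validation-loss helpers reduce a loss over a data loader. The sketch below averages a per-sample loss over every batch a loader yields, without any framework. It assumes that batches are `(inputs, targets)` pairs. The real methods presumably use PyTorch and may reduce the loss differently, for example by averaging per-batch means:

```python
from typing import Callable, Iterable, Sequence, Tuple


def average_loss(loader: Iterable[Tuple[Sequence, Sequence]],
                 per_sample_loss: Callable[[object, object], float]) -> float:
    """Mean of per_sample_loss over all samples from all batches."""
    total, count = 0.0, 0
    for inputs, targets in loader:
        for x, y in zip(inputs, targets):
            total += per_sample_loss(x, y)
            count += 1
    if count == 0:
        raise ValueError("loader yielded no samples")
    return total / count
```

Weighting by sample count rather than averaging per-batch means keeps a smaller final batch from being over-weighted.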
_get_gradients(self)
_get_momentum(self, gradients)
get_adam_direction(self)
get_rmsprop_direction(self)
get_momentum_direction(self)
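get_adam_direction, get_rmsprop_direction and get_momentum_direction presumably return the update direction each optimizer would apply. The sketch below implements the textbook update rules on plain Python lists:

  • Adam, with bias correction
  • RMSprop
  • heavy-ball momentum, as in PyTorch's SGD

The hyperparameter defaults are assumptions and may differ from those used by SGDEnv:

```python
import math
from typing import List, Tuple

Vec = List[float]


def adam_direction(g: Vec, m: Vec, v: Vec, t: int,
                   beta1: float = 0.9, beta2: float = 0.999,
                   eps: float = 1e-8) -> Tuple[Vec, Vec, Vec]:
    """Return (direction, new_m, new_v) for step t >= 1."""
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]
    # Bias correction for the zero-initialized moment estimates.
    m_hat = [mi / (1 - beta1 ** t) for mi in m]
    v_hat = [vi / (1 - beta2 ** t) for vi in v]
    d = [mh / (math.sqrt(vh) + eps) for mh, vh in zip(m_hat, v_hat)]
    return d, m, v


def rmsprop_direction(g: Vec, v: Vec, alpha: float = 0.99,
                      eps: float = 1e-8) -> Tuple[Vec, Vec]:
    """Return (direction, new_v): the gradient scaled by a running RMS."""
    v = [alpha * vi + (1 - alpha) * gi * gi for vi, gi in zip(v, g)]
    d = [gi / (math.sqrt(vi) + eps) for gi, vi in zip(g, v)]
    return d, v


def momentum_direction(g: Vec, buf: Vec, mu: float = 0.9) -> Vec:
    """Heavy-ball momentum buffer update: buf <- mu * buf + g."""
    return [mu * bi + gi for bi, gi in zip(buf, g)]
```

On the first Adam step, the bias-corrected moments equal g and g², so each direction component is roughly sign(g).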
_get_loss_features(self)
_get_predictive_change_features(self, lr)
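_get_predictive_change_features(lr) takes the learning rate as input. One plausible interpretation draws on the fact that the environment uses BackPACK, which can provide per-example gradients. Under that reading, the method makes a first-order Taylor estimate of how much each example's loss would change if the parameters moved as θ ← θ − lr·d, giving Δℓᵢ ≈ −lr · gᵢ·d. It then summarizes these estimates by their mean and variance. The sketch below is hypothetical and is not the dacbench feature definition:

```python
from statistics import mean, pvariance
from typing import Sequence, Tuple


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def predictive_change(per_example_grads: Sequence[Sequence[float]],
                      direction: Sequence[float],
                      lr: float) -> Tuple[float, float]:
    """First-order predicted loss change per example, summarized as (mean, variance)."""
    changes = [-lr * dot(g, direction) for g in per_example_grads]
    return mean(changes), pvariance(changes)
```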
_get_alignment(self)
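_get_alignment is undocumented. A common way to measure the alignment of two gradient vectors is cosine similarity. The sketch below assumes, hypothetically, that the current gradient is compared with the previous step's gradient. A value of 1 means the two point the same way, 0 means orthogonal and -1 means opposite:

```python
import math
from typing import Sequence


def cosine_alignment(current: Sequence[float], previous: Sequence[float]) -> float:
    """Cosine similarity between two gradient vectors (0.0 if either is zero)."""
    dot = sum(c * p for c, p in zip(current, previous))
    norm = math.sqrt(sum(c * c for c in current)) * math.sqrt(sum(p * p for p in previous))
    if norm == 0.0:
        return 0.0
    return dot / norm
```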
generate_instance_file(self, file_name, mode='test', n=100)