The SGD Deep Learning Benchmark
Built on top of PyTorch, this benchmark allows for dynamic learning rate control in deep learning. At each step until the cutoff, i.e., after each epoch, the DAC controller supplies a new learning rate value to the network. Success is measured by a decreasing validation loss.
This is a very flexible benchmark: in principle, any classification dataset and any PyTorch-compatible architecture can be used for training. The underlying task is not easy, however, so we recommend starting with small networks and datasets and building up to harder tasks.
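The control loop described above can be sketched as follows. This is a schematic, self-contained illustration: `DummySGDEnv` is a stand-in invented here to mimic the reset/step/close interface of dacbench's `SGDEnv`; with dacbench installed, the real environment would come from `SGDBenchmark().get_environment()` instead.

```python
# Schematic DAC loop: a controller sets a new learning rate each epoch
# until the cutoff, and the reward tracks validation loss.
class DummySGDEnv:
    """Stand-in mimicking the reset/step/close interface of SGDEnv."""

    def __init__(self, cutoff=5):
        self.cutoff = cutoff
        self.epoch = 0

    def reset(self):
        self.epoch = 0
        return {"step": 0, "validation_loss": 1.0}

    def step(self, action):
        # `action` is the learning rate chosen by the DAC controller
        self.epoch += 1
        state = {"step": self.epoch,
                 "validation_loss": 1.0 / (1 + self.epoch)}
        reward = -state["validation_loss"]  # success: decreasing val. loss
        done = self.epoch >= self.cutoff
        return state, reward, done, {}

    def close(self):
        return True


env = DummySGDEnv()
state = env.reset()
done = False
rewards = []
while not done:
    learning_rate = 0.001  # a (here: constant) policy picks the new LR
    state, reward, done, info = env.step(learning_rate)
    rewards.append(reward)
env.close()
```

In the dummy environment the validation loss shrinks each epoch, so the reward improves over the episode; with the real `SGDEnv`, the reward depends on how well the chosen learning rates actually train the network.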
- class dacbench.benchmarks.sgd_benchmark.SGDBenchmark(config_path=None, config=None)
Bases:
AbstractBenchmark
Benchmark with default configuration & relevant functions for SGD
- get_benchmark(instance_set_path=None, seed=0)
Get benchmark from the LTO paper
- Parameters
instance_set_path (str) – Path to the instance set
seed (int) – Environment seed
- Returns
env – SGD environment
- Return type
SGDEnv
- get_environment()
Return SGDEnv env with current configuration
- Returns
SGD environment
- Return type
SGDEnv
- read_instance_set(test=False)
Read instance paths from the config into a list
- class dacbench.envs.sgd.Reward(value)
Bases:
IntEnum
An enumeration.
- class dacbench.envs.sgd.SGDEnv(config)
Bases:
AbstractEnv
Environment to control the learning rate of the Adam optimizer
- close()
No additional cleanup necessary
- Returns
Cleanup flag
- Return type
bool
- get_default_state(_)
Gather state description
- Returns
Environment state
- Return type
dict
- render(mode: str = 'human')
Render env in human mode
- Parameters
mode (str) – Execution mode
- reset()
Reset environment
- Returns
Environment state
- Return type
np.array
- seed(seed=None, seed_action_space=False)
Set rng seed
- Parameters
seed – seed for rng
seed_action_space (bool, default False) – whether to seed the action space as well
- step(action)
Execute environment step
- Parameters
action (list) – action to execute
- Returns
state, reward, done, info
- Return type
np.array, float, bool, dict
- val_model
Samuel Mueller (a PhD student in our group) also uses backpack and has run into a similar memory leak. He solved it by calling this custom-made recursive memory_cleanup function:

    import torch
    from backpack import memory_cleanup

    def recursive_backpack_memory_cleanup(module: torch.nn.Module):
        memory_cleanup(module)
        for m in module.modules():
            memory_cleanup(m)

Calling this after computing the training loss/gradients and after the validation loss should suffice.
- Type
TODO
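The recursive cleanup pattern quoted above can be sketched in a self-contained way. `DummyModule` and `recursive_cleanup` are stand-ins invented for illustration; in the real setting, `memory_cleanup` comes from backpack and the module is a `torch.nn.Module`, whose `modules()` iterator likewise yields the module itself before its descendants.

```python
# Generic sketch of the recursive cleanup pattern: apply a per-module
# cleanup function to a module and every submodule it contains.
class DummyModule:
    """Mimics torch.nn.Module's modules() traversal."""

    def __init__(self, children=()):
        self._children = list(children)

    def modules(self):
        # Yield self first, then all descendants (as torch does).
        yield self
        for child in self._children:
            yield from child.modules()


def recursive_cleanup(module, cleanup):
    # Mirrors the quoted helper: clean the root, then everything
    # reachable via modules(). Note the root is visited twice, exactly
    # as in the original snippet.
    cleanup(module)
    for m in module.modules():
        cleanup(m)


cleaned = []
root = DummyModule([DummyModule(), DummyModule()])
recursive_cleanup(root, cleaned.append)
# root is cleaned twice (directly and via modules()), each child once
```

Since cleanup functions of this kind are typically idempotent, the double visit of the root module is harmless; it simply reproduces the behavior of the quoted helper.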