Commit e06d8eef authored by lli

Corrected the learning rate scheduler

parent 0a799f4c
@@ -8,7 +8,7 @@ device = torch.device('cuda' if USE_CUDA else 'cpu')
 # Define the policy net
 class PolicyNetwork(nn.Module):
-    def __init__(self, n_state, n_action, n_hidden, lr):
+    def __init__(self, n_state, n_action, n_hidden, lr, schedule_step=100000, schedule_rate=0.5, lr_schedule=False):
         '''
         Initialize the policy neural network:
         Use one hidden layer
@@ -30,7 +30,11 @@ class PolicyNetwork(nn.Module):
         )
         self.optimizer = torch.optim.Adam(self.network.parameters(), lr)
-        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=100, gamma=0.1)
+        self.lr_schedule = lr_schedule
+        if self.lr_schedule:
+            self.schedule_step = schedule_step
+            self.schedule_rate = schedule_rate
+            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.schedule_step, gamma=self.schedule_rate)

     def predict(self, state):
         # Compute the action probabilities of state s using the policy network
@@ -55,7 +59,8 @@ class PolicyNetwork(nn.Module):
         self.optimizer.zero_grad()
         loss.backward()
         self.optimizer.step()
-        self.scheduler.step()
+        if self.lr_schedule:
+            self.scheduler.step()

     def get_action(self, state):
         probs = self.predict(state)
......
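For readers skimming the diff: the two hunks above make the StepLR schedule opt-in. The scheduler is only constructed when lr_schedule is requested, and only stepped after an optimizer update when that flag is set. Below is a minimal, self-contained sketch of that pattern; the layer sizes, the softmax head, predict(), and the loss inside update() are assumptions, since the diff only shows fragments of the class.

import torch
import torch.nn as nn

class PolicyNetwork(nn.Module):
    # Sketch of the pattern introduced above: StepLR is optional and only stepped when enabled.
    def __init__(self, n_state, n_action, n_hidden, lr,
                 schedule_step=100000, schedule_rate=0.5, lr_schedule=False):
        super().__init__()
        # One-hidden-layer softmax policy assumed; the diff does not show the full Sequential.
        self.network = nn.Sequential(
            nn.Linear(n_state, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_action),
            nn.Softmax(dim=-1),
        )
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr)
        self.lr_schedule = lr_schedule
        if self.lr_schedule:
            # Multiply the learning rate by schedule_rate every schedule_step optimizer steps.
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer, step_size=schedule_step, gamma=schedule_rate)

    def predict(self, state):
        # Action probabilities for a single state.
        return self.network(torch.as_tensor(state, dtype=torch.float32))

    def update(self, advantages, log_probs):
        # Usual REINFORCE-with-baseline loss; the exact loss in the repo is not visible in the diff.
        loss = -(torch.stack(log_probs) * torch.as_tensor(advantages, dtype=torch.float32)).sum()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        if self.lr_schedule:
            self.scheduler.step()  # decay only when the schedule is enabled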
@@ -14,4 +14,5 @@ n_state = len(env.state)  # Number of input
 n_action = env.action_space.n  # Number of output
 gamma = 1  # Discount factor
+policy_lr_schedule = False  # Learning rate scheduler for policy net
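A short, hypothetical sketch of how this flag is threaded into the constructor changed in the previous file; n_hidden and lr below are placeholder values, not taken from the repo.

policy_net = PolicyNetwork(n_state, n_action, n_hidden=64, lr=0.001,
                           schedule_step=100000, schedule_rate=0.5,
                           lr_schedule=policy_lr_schedule)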
@@ -92,9 +92,11 @@ n_episode = args.n_episode  # Number of training episodes
 total_reward_episode = [0] * n_episode
 num_no_capacity = []
 accepted_orders = []
-lr = []
 start_time = timer()
+if policy_lr_schedule:
+    lr = []

 #############################
 # Training
 #############################
@@ -141,17 +143,21 @@ for episode in range(n_episode):
         advantages = (returns - baseline_values)
         value_net.update(states, returns)
-        lr.append(policy_net.optimizer.param_groups[0]['lr'])
+        if policy_lr_schedule:
+            lr.append(policy_net.optimizer.param_groups[0]['lr'])
         # Update nn based on discounted rewards and log_probs
         policy_net.update(advantages, log_probs)
-        print('Episode: {}, total reward: {}, number of penalties: {}, accepted orders: {}, learning rate: {}'.format(
-            episode, total_reward_episode[episode], num_no_capacity[episode], accepted_orders[episode], lr[episode]))
+        if policy_lr_schedule:
+            print('Episode: {}, total reward: {}, number of penalties: {}, accepted orders: {}, learning rate: {}'.format(
+                episode, total_reward_episode[episode], num_no_capacity[episode], accepted_orders[episode], lr[episode]))
+        else:
+            print(f'Episode: {episode}, total reward: {total_reward_episode[episode]}, number of penalties: {num_no_capacity[episode]}, accepted orders: {accepted_orders[episode]}')
         # print('Episode: {}, selected action: {}'.format(episode, selected_action))
         break
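The advantages passed to policy_net.update above are the discounted returns minus the value net's baseline predictions. As a reference, here is a self-contained sketch of that computation with toy numbers; the helper function and the stand-in baseline are illustrative, not code from this repo.

import torch

def discounted_returns(rewards, gamma=1.0):
    # G_t = r_t + gamma * G_{t+1}, computed backwards over one episode.
    g, out = 0.0, []
    for r in reversed(rewards):
        g = r + gamma * g
        out.insert(0, g)
    return torch.tensor(out, dtype=torch.float32)

returns = discounted_returns([1.0, 0.0, 2.0], gamma=1.0)   # toy episode rewards
baseline_values = torch.zeros_like(returns)                # stand-in for the value net's predictions
advantages = returns - baseline_values                     # same expression as in the hunk above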
@@ -186,11 +192,12 @@ plt.legend(handles=[num_penalty, avg_penalty], loc='best')
 plt.savefig(os.path.join(OUT_PATH, 'penalties.png'), dpi=1200, transparent=True, bbox_inches='tight')
 plt.close()
-plt.title('Learning rate decay')
-plt.xlabel('Episode')
-plt.ylabel('Learning rate')
-plt.savefig(os.path.join(OUT_PATH, 'learning_rate.png'), dpi=1200, transparent=True, bbox_inches='tight')
-plt.close()
+if policy_lr_schedule:
+    plt.title('Learning rate decay')
+    plt.xlabel('Episode')
+    plt.ylabel('Learning rate')
+    plt.savefig(os.path.join(OUT_PATH, 'learning_rate.png'), dpi=1200, transparent=True, bbox_inches='tight')
+    plt.close()

 #############################
 # Evaluation
......
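The hunk above only guards the title, labels, and savefig; the plt.plot call for the recorded learning rates is not visible in it. A runnable sketch of what the guarded block presumably amounts to, with toy stand-ins for the script's variables:

import os
import matplotlib.pyplot as plt

# Toy stand-ins; in the script these come from the config and the training loop.
policy_lr_schedule = True
OUT_PATH = '.'
lr = [0.001 * (0.5 ** (i // 100)) for i in range(300)]   # illustrative decayed LR trace

if policy_lr_schedule:
    plt.plot(lr)   # assumed: one learning-rate value recorded per episode
    plt.title('Learning rate decay')
    plt.xlabel('Episode')
    plt.ylabel('Learning rate')
    plt.savefig(os.path.join(OUT_PATH, 'learning_rate.png'), dpi=1200, transparent=True, bbox_inches='tight')
    plt.close()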