Commit 25c32d11 authored by lli's avatar lli
Browse files

update loss plot

parent 9f0e8ab6
......@@ -5,4 +5,5 @@ __pycache__/
results/
slurm/
runs/
evaluation/
\ No newline at end of file
evaluation/
.vscode/
\ No newline at end of file
......@@ -66,6 +66,7 @@ class PolicyNetwork(nn.Module):
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss
# self.scheduler.step()
def get_action(self, state):
......
......@@ -87,6 +87,7 @@ accepted_orders = []
start_time = timer()
losses = []
#############################
# Training
......@@ -127,8 +128,8 @@ for episode in range(n_episode):
returns = (returns - returns.mean()) / (returns.std() + 1e-9)
policy_net.update(returns, log_probs, state_values)
loss = policy_net.update(returns, log_probs, state_values)
losses.append(loss)
print(f'Episode: {episode}, total reward: {total_reward_episode[episode]}, number of penalties: {num_no_capacity[episode]}, accepted orders: {accepted_orders[episode]}')
break
......@@ -162,6 +163,15 @@ plt.legend(handles=[num_penalty, avg_penalty], loc='best')
plt.savefig(os.path.join(OUT_PATH, 'penalties.png'), dpi=1200, transparent=True, bbox_inches='tight')
plt.close()
# Plot loss
loss_smoothed = sliding_window(losses, len(losses))
plt.plot(loss_smoothed)
plt.xlabel('Episode')
plt.ylabel('Loss')
plt.title('Training loss')
plt.savefig(os.path.join(OUT_PATH, 'loss.png'), dpi=1200, transparent=True, bbox_inches='tight')
plt.close()
#############################
# Evaluation
......@@ -202,6 +212,7 @@ save_list(total_reward_episode, EVA_FILE, 'total_reward_episode_train')
save_list(total_reward_episode_eva, EVA_FILE, 'total_reward_episode_eva')
save_list(num_no_capacity_eva, EVA_FILE, 'num_no_capacity_eva')
save_list(accepted_orders_eva, EVA_FILE, 'accepted_orders_eva')
save_list(losses, EVA_FILE, 'losses')
# Load optimal solution
......@@ -242,5 +253,5 @@ labels = ['True Neg', 'False Pos', 'False Neg', 'True Pos']
categories = ['Reject', 'Accept']
make_confusion_matrix(cf_matrix, group_names=labels, categories=categories, cmap='Blues')
plt.tight_layout()
plt.savefig(os.path.join(OUT_PATH, 'confusion_matrix.png'), transparent=True, bbox_inches='tight')
plt.savefig(os.path.join(OUT_PATH, 'confusion_matrix.png'), dpi=1200, transparent=True, bbox_inches='tight')
plt.close()
......@@ -209,14 +209,14 @@ plt.close()
# Plot policy net loss
plt.plot(p_running_losses)
plt.plot(sliding_window(p_losses, len(p_losses)))
plt.title('Policy net losses')
plt.xlabel('Episode')
plt.savefig(os.path.join(OUT_PATH, 'policy_loss.png'), dpi=1200, transparent=True, bbox_inches='tight')
plt.close()
# Plot value net loss
plt.plot(v_running_losses)
plt.plot(sliding_window(v_losses, len(v_losses)))
plt.title('Value net losses')
plt.xlabel('Episode')
plt.savefig(os.path.join(OUT_PATH, 'value_loss.png'), dpi=1200, transparent=True, bbox_inches='tight')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment