Commit f83a4918 authored by lli

update

parent 1059735f
@@ -55,7 +55,7 @@ class PolicyNetwork(nn.Module):
         loss = 0
         for log_prob, value, Gt in zip(log_probs, state_values, returns):
             advantage = Gt - value.item()
-            policy_loss = (-log_prob * advantage)
+            policy_loss = -log_prob * advantage
             Gt = torch.unsqueeze(Gt, 0)
             value_loss = F.smooth_l1_loss(value, Gt)
@@ -65,8 +65,7 @@ class PolicyNetwork(nn.Module):
         self.optimizer.zero_grad()
         loss.backward()
         self.optimizer.step()
-        return loss
-        # self.scheduler.step()
+        return loss.item()

     def get_action(self, state):
         action_probs, state_value = self.predict(state)
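For context, here is a minimal, self-contained sketch of the update method this commit touches: a REINFORCE-with-baseline step in which the policy loss is weighted by the baseline-subtracted advantage and the value head is regressed toward the return with smooth L1 loss. The loop body and the new return loss.item() follow the diff; the constructor, layer sizes, and forward pass are assumptions for illustration, not the repository's actual code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class PolicyNetwork(nn.Module):
    # Hypothetical constructor; the real layer sizes are not shown in the diff.
    def __init__(self, n_state, n_action, n_hidden=64, lr=1e-3):
        super().__init__()
        self.fc = nn.Linear(n_state, n_hidden)
        self.action_head = nn.Linear(n_hidden, n_action)  # policy logits
        self.value_head = nn.Linear(n_hidden, 1)          # state-value baseline
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def forward(self, x):
        x = F.relu(self.fc(x))
        return F.softmax(self.action_head(x), dim=-1), self.value_head(x)

    def update(self, log_probs, state_values, returns):
        # Loop body as in the diff: policy gradient with a value baseline.
        loss = 0
        for log_prob, value, Gt in zip(log_probs, state_values, returns):
            advantage = Gt - value.item()          # baseline-subtracted return
            policy_loss = -log_prob * advantage    # REINFORCE term
            Gt = torch.unsqueeze(Gt, 0)            # match value's shape (1,)
            value_loss = F.smooth_l1_loss(value, Gt)
            loss += policy_loss + value_loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()                         # plain float, as of this commit

Returning loss.item() rather than the loss tensor, as this commit does, hands the caller a plain float and lets the autograd graph be freed right after optimizer.step() instead of being kept alive by the returned value.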