Commit 8d49c5df authored by lli

update

parent 3f9ca7fe
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

USE_CUDA = torch.cuda.is_available()
device = torch.device('cuda' if USE_CUDA else 'cpu')


class Actor(nn.Module):
    def __init__(self, n_state, n_hidden, n_action):
        """
        Actor network: takes the input state and outputs the action probabilities.
        It learns the optimal policy by updating its parameters using the information provided by the critic.
        :param n_state: dimension of the environment state (number of input features)
        :param n_hidden: [list] number of neurons in each hidden layer
        :param n_action: number of neurons in the output layer (number of actions)
        """
        super(Actor, self).__init__()
        self.n_state = n_state
        self.n_action = n_action
        self.input = nn.Linear(self.n_state, n_hidden[0])
        self.hidden = nn.ModuleList()
        for k in range(len(n_hidden) - 1):
            self.hidden.append(nn.Linear(n_hidden[k], n_hidden[k + 1]))
        self.out = nn.Linear(n_hidden[-1], self.n_action)

    def forward(self, state):
        output = F.relu(self.input(state))
        for m in self.hidden:
            # Apply a ReLU nonlinearity after every hidden layer
            output = F.relu(m(output))
        output = self.out(output)
        distribution = Categorical(F.softmax(output, dim=-1))
        return distribution


class Critic(nn.Module):
    def __init__(self, n_state, n_hidden, n_action):
        """
        Critic network: evaluates how good it is to be in the input state by computing the value function.
        The value guides the actor on how it should adjust its policy.
        """
        super(Critic, self).__init__()
        self.n_state = n_state
        self.n_action = n_action
        self.input = nn.Linear(self.n_state, n_hidden[0])
        self.hidden = nn.ModuleList()
        for k in range(len(n_hidden) - 1):
            self.hidden.append(nn.Linear(n_hidden[k], n_hidden[k + 1]))
        self.out = nn.Linear(n_hidden[-1], n_action)

    def forward(self, state):
        output = F.relu(self.input(state))
        for m in self.hidden:
            # Apply a ReLU nonlinearity after every hidden layer
            output = F.relu(m(output))
        value = self.out(output)
        return value
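Not part of the commit: the sketch below shows one way the Actor and Critic defined above could be wired together for a single one-step actor-critic update. It reuses the imports, device, and classes from this file; the hyperparameters, the dummy transition, and the choice to give the critic a single output are illustrative assumptions, not the repository's actual training procedure.

# Illustrative sketch only -- hyperparameters and the dummy transition are made up.
n_state, n_hidden, n_action = 4, [64, 64], 2

actor = Actor(n_state, n_hidden, n_action).to(device)
critic = Critic(n_state, n_hidden, 1).to(device)   # single output -> scalar state value
actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)
critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-3)

state = torch.rand(n_state, device=device)        # dummy current state
next_state = torch.rand(n_state, device=device)   # dummy next state
reward, done, gamma = 1.0, False, 0.99

dist = actor(state)                # Categorical distribution over actions
action = dist.sample()             # sample an action from the policy
log_prob = dist.log_prob(action)

value = critic(state)                      # V(s)
next_value = critic(next_state).detach()   # V(s'), no gradient through the target

# One-step TD target and advantage
td_target = reward + gamma * next_value * (1.0 - float(done))
advantage = (td_target - value).detach()

actor_loss = -log_prob * advantage           # policy-gradient loss for the actor
critic_loss = F.mse_loss(value, td_target)   # regression loss for the critic

actor_optim.zero_grad()
actor_loss.backward()
actor_optim.step()

critic_optim.zero_grad()
critic_loss.backward()
critic_optim.step()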
@@ -7,7 +7,7 @@ device = torch.device('cuda' if USE_CUDA else 'cpu')
 # Define the policy net
 class PolicyNetwork(nn.Module):
-    def __init__(self, n_state, n_action, n_hidden, lr, schedule_step=100000, schedule_rate=0.5, lr_schedule=False):
+    def __init__(self, n_state, n_hidden, n_action, lr):
         '''
         Initialize the policy neural network:
         Use one hidden layer
@@ -23,18 +23,10 @@ class PolicyNetwork(nn.Module):
             nn.Linear(n_hidden, n_hidden),
             nn.ReLU(),
-            nn.Linear(n_hidden, n_hidden),
-            nn.ReLU(),
             nn.Linear(n_hidden, n_action),
-            nn.Softmax(dim=-1),
+            nn.Softmax(dim=-1)
         )
         self.optimizer = torch.optim.Adam(self.network.parameters(), lr)
-        self.lr_schedule = lr_schedule
-        if self.lr_schedule:
-            self.schedule_step = schedule_step
-            self.schedule_rate = schedule_rate
-            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.schedule_step,
-                                                             gamma=self.schedule_rate)
     def predict(self, state):
         # Compute the action probabilities of state s using the learning rate
@@ -59,8 +51,6 @@ class PolicyNetwork(nn.Module):
         self.optimizer.zero_grad()
         loss.backward()
         self.optimizer.step()
-        if self.lr_schedule:
-            self.scheduler.step()
         return loss
     def get_action(self, state):
@@ -85,8 +75,6 @@ class ValueNetwork(nn.Module):
             nn.ReLU(),
             nn.Linear(n_hidden, n_hidden),
             nn.ReLU(),
-            nn.Linear(n_hidden, n_hidden),
-            nn.ReLU(),
             nn.Linear(n_hidden, 1)
         )
         self.optimizer = torch.optim.Adam(self.model.parameters(), lr)
from environment.wendtris import Wendtris

#############################
# Initialize the environment, with penalty factor 2
#############################
env = Wendtris(20, 6, 6, 2)

#############################
# Model Params (fixed)
#############################
n_state = len(env.state)         # Number of input features (state dimension)
n_action = env.action_space.n    # Number of possible actions (output size)
gamma = 1                        # Discount factor (1 = no discounting)
\ No newline at end of file
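For reference (not part of this commit): gamma is the discount factor applied when accumulating episode returns, and it is fixed to 1 here, i.e. no discounting. The sketch below illustrates the standard recursion G_t = r_t + gamma * G_{t+1}; the function and variable names are illustrative and not taken from the training scripts.

def discounted_returns(rewards, gamma=1.0):
    # Accumulate returns backwards: G_t = r_t + gamma * G_{t+1}
    returns = []
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    returns.reverse()
    return returns

# With gamma = 1, each step simply receives the undiscounted sum of its future rewards
print(discounted_returns([1.0, 0.0, 2.0], gamma=1.0))   # [3.0, 2.0, 2.0]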
@@ -3,7 +3,7 @@ from itertools import product
 # Define different learning rates for learning rate tuning
 parameters = dict(
-    policy_lr=[0.01, 0.001, 0.0001, 0.00001],
+    policy_lr=[0.001, 0.0001, 0.00001],
     value_lr=[0.01, 0.001, 0.0001, 0.00001]
 )
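Not part of the diff: a dict of lists like this is commonly expanded into a full hyperparameter grid with itertools.product, as sketched below; the loop body is a placeholder and the actual tuning script may iterate differently.

from itertools import product

parameters = dict(
    policy_lr=[0.001, 0.0001, 0.00001],
    value_lr=[0.01, 0.001, 0.0001, 0.00001]
)

# Expand the dict of lists into every (policy_lr, value_lr) combination
for policy_lr, value_lr in product(*parameters.values()):
    # A real tuning script would launch a training run here; print is a placeholder
    print(f'policy_lr={policy_lr}, value_lr={value_lr}')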
@@ -105,8 +105,7 @@ v_running_losses = []
 start_time = timer()
-if policy_lr_schedule:
-    lr = []
 #############################
 # Training
 #############################
@@ -214,14 +213,6 @@ plt.legend(handles=[num_penalty, avg_penalty], loc='best')
 plt.savefig(os.path.join(OUT_PATH, 'penalties.png'), dpi=1200, transparent=True, bbox_inches='tight')
 plt.close()
-# Plot learning rate
-if policy_lr_schedule:
-    plt.plot(lr)
-    plt.title('Learning rate decay')
-    plt.xlabel('Episode')
-    plt.ylabel('Learning rate')
-    plt.savefig(os.path.join(OUT_PATH, 'learning_rate.png'), dpi=1200, transparent=True, bbox_inches='tight')
-    plt.close()
 # Plot policy net loss
 plt.plot(p_running_losses)