Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
lli
YM-Seminar
Commits
decea86e
Commit
decea86e
authored
Apr 30, 2021
by
lli
Browse files
Update the README
parent
3e7338b7
Changes
3
Hide whitespace changes
Inline
Side-by-side
README.md
0 → 100644
View file @
decea86e
## Pytorch Reinforcement Learning for Yield Management
This repo contains the implementation of deep reinforcement learning algorithms (DQN, REINFORCE with baseline, A2C) for yield management.
## Installation
Install all the packages via:
```
bash
pip
install
-r
requirements.txt
```
## Project Organization
|── algorithms <- Source code of the deep reinforcement learning algorithms
| |── a2c.py <- Actor-critic algorithm (A2C)
| |── dqn.py <- Deep Q Network algorithm
| |── reinforce.py <- REINFORCE with baseline algorithm
|── dp <- Dynamic programming scripts for the calculation of test data set
| |── dp_solver.py <- Script of dynamic programming for the optimal solution
|── environment <- Gym environment and request generator
| |── requestGenerator.py <- Script to generate order requests
| |── wendtris.py <- Wendtris environment
|── params <- Model parameters (fixed)
| |── a2c_params.py <- A2C model fixed hyperparameters
| |── dqn_params.py <- DQN model fixed hyperparameters
| |── reinforce_params.py <- REINFORCE with baseline fixed hyperparameters
|── utils <- Utility functions
| |── cf_matrix.py <- Script to generate and visualize confusion matrix
| |── utils.py <- Script of the utility functions
|── .gitignore
|── main.py
|── reinforce_tune_lr.py <- Script to tune learning rate of REINFORCE with baseline
|── requirements.txt
|── train_a2c.py <- Script to train A2C models
|── train_dqn.py <- Script to train DQN models
|── train_reinforce.py <- Script to train REINFORCE with baseline models
## Usage
*
To train the DQN agent
```
shell
python train_dqn.py
--save_path
[
file path]
--n_hidden
[
number of neurons]
--lr
[
learning rate]
--policy
[
epsilon_greedy/boltzmann]
--n_episode
[
training episodes]
--batch_size
[
size]
```
*
To train the REINFORCE with baseline agent
```
shell
python train_reinforce.py
--save_path
--n_hidden
--lr_policy
--lr_value
--n_episode
```
*
To train the A2C agent
```
shell
python train_a2c.py
--save_path
--n_hidden
--lr
--n_episode
```
algorithms/a2c.py
View file @
decea86e
...
...
@@ -55,7 +55,6 @@ class PolicyNetwork(nn.Module):
loss
=
0
for
log_prob
,
value
,
Gt
in
zip
(
log_probs
,
state_values
,
returns
):
advantage
=
Gt
-
value
.
item
()
#advantage = (advantage - advantage.mean()) / advantage.std()
policy_loss
=
(
-
log_prob
*
advantage
)
Gt
=
torch
.
unsqueeze
(
Gt
,
0
)
...
...
requirements.txt
View file @
decea86e
argon2-cffi
@
file:///C:/ci/argon2-cffi_1613037959010/work
async-generator
@
file:///home/ktietz/src/ci/async_generator_1611927993394/work
attrs
@
file:///tmp/build/80754af9/attrs_1604765588209/work
backcall
@
file:///home/ktietz/src/ci/backcall_1611930011877/work
bleach
@
file:///tmp/build/80754af9/bleach_1612211392645/work
absl-py
==0.12.0
argon2-cffi
==20.1.0
async-generator
==1.10
attrs
==20.3.0
backcall
==0.2.0
bleach
==3.3.0
cachetools
==4.2.1
certifi
==2020.12.5
cffi
@
file:///C:/ci/cffi_1613247279197/work
c
loudpickle
@
file:///home/conda/feedstock_root/build_artifacts/cloudpickle_1598400192773/work
c
o
lo
rama
@
file:///tmp/build/80754af9/colorama_1607707115595/work
cffi
==1.14.5
c
hardet
==4.0.0
clo
udpickle
==1.6.0
cycler
==0.10.0
decorator
@
file:///home/ktietz/src/ci/decorator_1611930055503/work
defusedxml
@
file:///tmp/build/80754af9/defusedxml_1615228127516/work
dataclasses
==0.8
decorator
==5.0.6
defusedxml
==0.7.1
entrypoints
==0.3
future
@
file:///D:/bld/future_1610147454374/work
gym
@
file:///D:/bld/gym_1608551988578/work
importlib-metadata
@
file:///tmp/build/80754af9/importlib-metadata_1602276842396/work
ipykernel
@
file:///C:/ci/ipykernel_1596190155316/work/dist/ipykernel-5.3.4-py3-none-any.whl
ipython
@
file:///C:/ci/ipython_1614616640087/work
ipython-genutils
@
file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
jedi
==0.17.0
Jinja2
@
file:///tmp/build/80754af9/jinja2_1612213139570/work
jsonschema
@
file:///tmp/build/80754af9/jsonschema_1602607155483/work
jupyter-client
@
file:///tmp/build/80754af9/jupyter_client_1601311786391/work
jupyter-core
@
file:///C:/ci/jupyter_core_1612213356021/work
jupyterlab-pygments
@
file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
kiwisolver
@
file:///C:/ci/kiwisolver_1612282606037/work
future
==0.18.2
google-auth
==1.29.0
google-auth-oauthlib
==0.4.4
grpcio
==1.37.0
gym
==0.18.0
idna
==2.10
importlib-metadata
==3.10.0
ipykernel
==5.5.3
ipython
==7.16.1
ipython-genutils
==0.2.0
ipywidgets
==7.6.3
jedi
==0.18.0
Jinja2
==2.11.3
joblib
==1.0.1
jsonschema
==3.2.0
jupyter
==1.0.0
jupyter-client
==6.1.12
jupyter-console
==6.4.0
jupyter-core
==4.7.1
jupyterlab-pygments
==0.1.2
jupyterlab-widgets
==1.0.0
kiwisolver
==1.3.1
Markdown
==3.3.4
MarkupSafe
==1.1.1
matplotlib
@
file:///C:/ci/matplotlib-suite_1613408055530/work
matplotlib
==3.3.4
mistune
==0.8.4
mkl-fft
==1.3.0
mkl-random
==1.1.1
mkl-service
==2.3.0
nbclient
@
file:///tmp/build/80754af9/nbclient_1614364831625/work
nbconvert
@
file:///C:/ci/nbconvert_1601914925608/work
nbformat
@
file:///tmp/build/80754af9/nbformat_1610738111109/work
nest-asyncio
@
file:///tmp/build/80754af9/nest-asyncio_1613680548246/work
notebook
@
file:///C:/ci/notebook_1611348396724/work
numpy
@
file:///C:/ci/numpy_and_numpy_base_1603466732592/work
olefile
==0.46
packaging
@
file:///tmp/build/80754af9/packaging_1611952188834/work
pandas
@
file:///C:/ci/pandas_1614711443373/work
pandocfilters
@
file:///C:/ci/pandocfilters_1605102497129/work
parso
@
file:///tmp/build/80754af9/parso_1607623074025/work
pickleshare
@
file:///tmp/build/80754af9/pickleshare_1606932040724/work
Pillow
@
file:///C:/ci/pillow_1615224342392/work
prometheus-client
@
file:///tmp/build/80754af9/prometheus_client_1606344362066/work
prompt-toolkit
@
file:///tmp/build/80754af9/prompt-toolkit_1602688806899/work
pycparser
@
file:///tmp/build/80754af9/pycparser_1594388511720/work
pyglet
@
file:///D:/bld/pyglet_1612861255678/work
Pygments
@
file:///tmp/build/80754af9/pygments_1615143339740/work
pyparsing
@
file:///home/linux1/recipes/ci/pyparsing_1610983426697/work
pyrsistent
@
file:///C:/ci/pyrsistent_1600141795814/work
python-dateutil
@
file:///home/ktietz/src/ci/python-dateutil_1611928101742/work
pytz
@
file:///tmp/build/80754af9/pytz_1612215392582/work
pywin32
==227
pywinpty
==0.5.7
pyzmq
==20.0.0
scipy
@
file:///C:/ci/scipy_1614023125644/work
Send2Trash
@
file:///tmp/build/80754af9/send2trash_1607525499227/work
sip
==4.19.13
six
@
file:///C:/ci/six_1605187374963/work
terminado
==0.9.2
testpath
@
file:///home/ktietz/src/ci/testpath_1611930608132/work
tornado
@
file:///C:/ci/tornado_1606942392901/work
traitlets
@
file:///home/ktietz/src/ci/traitlets_1611929699868/work
typing-extensions
@
file:///home/ktietz/src/ci_mi/typing_extensions_1612808209620/work
wcwidth
@
file:///tmp/build/80754af9/wcwidth_1593447189090/work
mpmath
==1.2.1
nbclient
==0.5.3
nbconvert
==6.0.7
nbformat
==5.1.3
nest-asyncio
==1.5.1
nose
==1.3.7
notebook
==6.3.0
numpy
==1.19.5
oauthlib
==3.1.0
packaging
==20.9
pandas
==1.1.5
pandocfilters
==1.4.3
parso
==0.8.2
pexpect
==4.8.0
pickleshare
==0.7.5
Pillow
==7.2.0
prometheus-client
==0.10.1
prompt-toolkit
==3.0.18
protobuf
==3.15.8
ptyprocess
==0.7.0
pyasn1
==0.4.8
pyasn1-modules
==0.2.8
pycparser
==2.20
pyglet
==1.5.0
Pygments
==2.8.1
pyparsing
==2.4.7
pyrsistent
==0.17.3
pyslurm
==19.5.0.0
python-dateutil
==2.8.1
pytz
==2021.1
pyzmq
==22.0.3
qtconsole
==5.0.3
QtPy
==1.9.0
requests
==2.25.1
requests-oauthlib
==1.3.0
rsa
==4.7.2
scikit-learn
==0.24.1
scipy
==1.5.4
seaborn
==0.11.1
Send2Trash
==1.5.0
six
==1.15.0
sympy
==1.7.1
tensorboard
==2.5.0
tensorboard-data-server
==0.6.0
tensorboard-plugin-wit
==1.8.0
terminado
==0.9.4
testpath
==0.4.4
threadpoolctl
==2.1.0
torch
==1.8.1+cu111
torchaudio
==0.8.1
torchvision
==0.9.1+cu111
tornado
==6.1
traitlets
==4.3.3
typing-extensions
==3.7.4.3
urllib3
==1.26.4
wcwidth
==0.2.5
webencodings
==0.5.1
wincertstore
==0.2
zipp
@
file:///tmp/build/80754af9/zipp_1604001098328/work
Werkzeug
==1.0.1
widgetsnbextension
==3.5.1
zipp
==3.4.1
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment