Commit decea86e authored by lli's avatar lli
Browse files

Update the README

parent 3e7338b7
## Pytorch Reinforcement Learning for Yield Management
This repo contains the implementation of deep reinforcement learning algorithms (DQN, REINFORCE with baseline, A2C) for yield management.
## Installation
Install all the packages via:
```bash
pip install -r requirements.txt
```
## Project Organization
|── algorithms <- Source code of the deep reinforcement learning algorithms
| |── a2c.py <- Actor-critic algorithm (A2C)
| |── dqn.py <- Deep Q Network algorithm
| |── reinforce.py <- REINFORCE with baseline algorithm
|── dp <- Dynamic programming scripts for the calculation of test data set
| |── dp_solver.py <- Script of dynamic programming for the optimal solution
|── environment <- Gym environment and request generator
| |── requestGenerator.py <- Script to generate order requests
| |── wendtris.py <- Wendtris environment
|── params <- Model parameters (fixed)
| |── a2c_params.py <- A2C model fixed hyperparameters
| |── dqn_params.py <- DQN model fixed hyperparameters
| |── reinforce_params.py <- REINFORCE with baseline fixed hyperparameters
|── utils <- Utility functions
| |── cf_matrix.py <- Script to generate and visualize confusion matrix
| |── utils.py <- Script of the utility functions
|── .gitignore
|── main.py
|── reinforce_tune_lr.py <- Script to tune learning rate of REINFORCE with baseline
|── requirements.txt
|── train_a2c.py <- Script to train A2C models
|── train_dqn.py <- Script to train DQN models
|── train_reinforce.py <- Script to train REINFORCE with baseline models
## Usage
* To train the DQN agent
```shell
python train_dqn.py --save_path [file path] --n_hidden [number of neurons] --lr [learning rate] --policy [epsilon_greedy/boltzmann] --n_episode [training episodes] --batch_size [size]
```
* To train the REINFORCE with baseline agent
```shell
python train_reinforce.py --save_path --n_hidden --lr_policy --lr_value --n_episode
```
* To train the A2C agent
```shell
python train_a2c.py --save_path --n_hidden --lr --n_episode
```
......@@ -55,7 +55,6 @@ class PolicyNetwork(nn.Module):
loss = 0
for log_prob, value, Gt in zip(log_probs, state_values, returns):
advantage = Gt - value.item()
#advantage = (advantage - advantage.mean()) / advantage.std()
policy_loss = (-log_prob * advantage)
Gt = torch.unsqueeze(Gt, 0)
......
argon2-cffi @ file:///C:/ci/argon2-cffi_1613037959010/work
async-generator @ file:///home/ktietz/src/ci/async_generator_1611927993394/work
attrs @ file:///tmp/build/80754af9/attrs_1604765588209/work
backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
bleach @ file:///tmp/build/80754af9/bleach_1612211392645/work
absl-py==0.12.0
argon2-cffi==20.1.0
async-generator==1.10
attrs==20.3.0
backcall==0.2.0
bleach==3.3.0
cachetools==4.2.1
certifi==2020.12.5
cffi @ file:///C:/ci/cffi_1613247279197/work
cloudpickle @ file:///home/conda/feedstock_root/build_artifacts/cloudpickle_1598400192773/work
colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work
cffi==1.14.5
chardet==4.0.0
cloudpickle==1.6.0
cycler==0.10.0
decorator @ file:///home/ktietz/src/ci/decorator_1611930055503/work
defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
dataclasses==0.8
decorator==5.0.6
defusedxml==0.7.1
entrypoints==0.3
future @ file:///D:/bld/future_1610147454374/work
gym @ file:///D:/bld/gym_1608551988578/work
importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1602276842396/work
ipykernel @ file:///C:/ci/ipykernel_1596190155316/work/dist/ipykernel-5.3.4-py3-none-any.whl
ipython @ file:///C:/ci/ipython_1614616640087/work
ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
jedi==0.17.0
Jinja2 @ file:///tmp/build/80754af9/jinja2_1612213139570/work
jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1601311786391/work
jupyter-core @ file:///C:/ci/jupyter_core_1612213356021/work
jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
kiwisolver @ file:///C:/ci/kiwisolver_1612282606037/work
future==0.18.2
google-auth==1.29.0
google-auth-oauthlib==0.4.4
grpcio==1.37.0
gym==0.18.0
idna==2.10
importlib-metadata==3.10.0
ipykernel==5.5.3
ipython==7.16.1
ipython-genutils==0.2.0
ipywidgets==7.6.3
jedi==0.18.0
Jinja2==2.11.3
joblib==1.0.1
jsonschema==3.2.0
jupyter==1.0.0
jupyter-client==6.1.12
jupyter-console==6.4.0
jupyter-core==4.7.1
jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.0.0
kiwisolver==1.3.1
Markdown==3.3.4
MarkupSafe==1.1.1
matplotlib @ file:///C:/ci/matplotlib-suite_1613408055530/work
matplotlib==3.3.4
mistune==0.8.4
mkl-fft==1.3.0
mkl-random==1.1.1
mkl-service==2.3.0
nbclient @ file:///tmp/build/80754af9/nbclient_1614364831625/work
nbconvert @ file:///C:/ci/nbconvert_1601914925608/work
nbformat @ file:///tmp/build/80754af9/nbformat_1610738111109/work
nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1613680548246/work
notebook @ file:///C:/ci/notebook_1611348396724/work
numpy @ file:///C:/ci/numpy_and_numpy_base_1603466732592/work
olefile==0.46
packaging @ file:///tmp/build/80754af9/packaging_1611952188834/work
pandas @ file:///C:/ci/pandas_1614711443373/work
pandocfilters @ file:///C:/ci/pandocfilters_1605102497129/work
parso @ file:///tmp/build/80754af9/parso_1607623074025/work
pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
Pillow @ file:///C:/ci/pillow_1615224342392/work
prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1606344362066/work
prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1602688806899/work
pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
pyglet @ file:///D:/bld/pyglet_1612861255678/work
Pygments @ file:///tmp/build/80754af9/pygments_1615143339740/work
pyparsing @ file:///home/linux1/recipes/ci/pyparsing_1610983426697/work
pyrsistent @ file:///C:/ci/pyrsistent_1600141795814/work
python-dateutil @ file:///home/ktietz/src/ci/python-dateutil_1611928101742/work
pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work
pywin32==227
pywinpty==0.5.7
pyzmq==20.0.0
scipy @ file:///C:/ci/scipy_1614023125644/work
Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work
sip==4.19.13
six @ file:///C:/ci/six_1605187374963/work
terminado==0.9.2
testpath @ file:///home/ktietz/src/ci/testpath_1611930608132/work
tornado @ file:///C:/ci/tornado_1606942392901/work
traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work
typing-extensions @ file:///home/ktietz/src/ci_mi/typing_extensions_1612808209620/work
wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
mpmath==1.2.1
nbclient==0.5.3
nbconvert==6.0.7
nbformat==5.1.3
nest-asyncio==1.5.1
nose==1.3.7
notebook==6.3.0
numpy==1.19.5
oauthlib==3.1.0
packaging==20.9
pandas==1.1.5
pandocfilters==1.4.3
parso==0.8.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==7.2.0
prometheus-client==0.10.1
prompt-toolkit==3.0.18
protobuf==3.15.8
ptyprocess==0.7.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.20
pyglet==1.5.0
Pygments==2.8.1
pyparsing==2.4.7
pyrsistent==0.17.3
pyslurm==19.5.0.0
python-dateutil==2.8.1
pytz==2021.1
pyzmq==22.0.3
qtconsole==5.0.3
QtPy==1.9.0
requests==2.25.1
requests-oauthlib==1.3.0
rsa==4.7.2
scikit-learn==0.24.1
scipy==1.5.4
seaborn==0.11.1
Send2Trash==1.5.0
six==1.15.0
sympy==1.7.1
tensorboard==2.5.0
tensorboard-data-server==0.6.0
tensorboard-plugin-wit==1.8.0
terminado==0.9.4
testpath==0.4.4
threadpoolctl==2.1.0
torch==1.8.1+cu111
torchaudio==0.8.1
torchvision==0.9.1+cu111
tornado==6.1
traitlets==4.3.3
typing-extensions==3.7.4.3
urllib3==1.26.4
wcwidth==0.2.5
webencodings==0.5.1
wincertstore==0.2
zipp @ file:///tmp/build/80754af9/zipp_1604001098328/work
Werkzeug==1.0.1
widgetsnbextension==3.5.1
zipp==3.4.1
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment