Commit decea86e authored by lli

Update the README

parent 3e7338b7
## PyTorch Reinforcement Learning for Yield Management
This repo contains the implementation of deep reinforcement learning algorithms (DQN, REINFORCE with baseline, A2C) for yield management.
## Installation
Install all the packages via:
```bash
pip install -r requirements.txt
```
## Project Organization
    ├── algorithms <- Source code of the deep reinforcement learning algorithms
    │   ├── a2c.py <- Actor-critic algorithm (A2C)
    │   ├── dqn.py <- Deep Q Network algorithm
    │   └── reinforce.py <- REINFORCE with baseline algorithm
    ├── dp <- Dynamic programming scripts for the calculation of test data set
    │   └── dp_solver.py <- Script of dynamic programming for the optimal solution
    ├── environment <- Gym environment and request generator
    │   ├── requestGenerator.py <- Script to generate order requests
    │   └── wendtris.py <- Wendtris environment
    ├── params <- Model parameters (fixed)
    │   ├── a2c_params.py <- A2C model fixed hyperparameters
    │   ├── dqn_params.py <- DQN model fixed hyperparameters
    │   └── reinforce_params.py <- REINFORCE with baseline fixed hyperparameters
    ├── utils <- Utility functions
    │   ├── cf_matrix.py <- Script to generate and visualize confusion matrix
    │   └── utils.py <- Script of the utility functions
    ├── .gitignore
    ├── main.py
    ├── reinforce_tune_lr.py <- Script to tune learning rate of REINFORCE with baseline
    ├── requirements.txt
    ├── train_a2c.py <- Script to train A2C models
    ├── train_dqn.py <- Script to train DQN models
    └── train_reinforce.py <- Script to train REINFORCE with baseline models
## Usage
* To train the DQN agent:
```shell
python train_dqn.py --save_path [file path] --n_hidden [number of neurons] --lr [learning rate] --policy [epsilon_greedy/boltzmann] --n_episode [training episodes] --batch_size [size]
```
* To train the REINFORCE with baseline agent:
```shell
python train_reinforce.py --save_path [file path] --n_hidden [number of neurons] --lr_policy [policy learning rate] --lr_value [value learning rate] --n_episode [training episodes]
```
* To train the A2C agent:
```shell
python train_a2c.py --save_path [file path] --n_hidden [number of neurons] --lr [learning rate] --n_episode [training episodes]
```
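The flags listed for `train_dqn.py` suggest a command-line interface along the following lines. This is a minimal sketch, assuming the script uses `argparse`; the `build_parser` helper, the argument types, and every default value are assumptions for illustration and are not taken from the repository.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the DQN flags named in the README."""
    parser = argparse.ArgumentParser(description="Train a DQN agent (sketch)")
    parser.add_argument("--save_path", type=str, required=True,
                        help="file path for the trained model")
    parser.add_argument("--n_hidden", type=int, default=128,
                        help="number of neurons in the hidden layer (assumed default)")
    parser.add_argument("--lr", type=float, default=1e-3,
                        help="learning rate (assumed default)")
    parser.add_argument("--policy", choices=["epsilon_greedy", "boltzmann"],
                        default="epsilon_greedy",
                        help="exploration policy")
    parser.add_argument("--n_episode", type=int, default=1000,
                        help="number of training episodes (assumed default)")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="replay batch size (assumed default)")
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(
    ["--save_path", "models/dqn.pt", "--policy", "boltzmann", "--lr", "0.01"]
)
print(vars(args))
```

Restricting `--policy` with `choices` makes a typo in the policy name fail at startup rather than partway through training.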
@@ -55,7 +55,6 @@ class PolicyNetwork(nn.Module):
         loss = 0
         for log_prob, value, Gt in zip(log_probs, state_values, returns):
             advantage = Gt - value.item()
-            #advantage = (advantage - advantage.mean()) / advantage.std()
             policy_loss = (-log_prob * advantage)
             Gt = torch.unsqueeze(Gt, 0)
...
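Inside the loop, `advantage` holds the value for a single step, so standardizing it there would have no sample to take a mean or standard deviation over. If advantage normalization is wanted, it would have to be computed over all steps of an episode before the loop. The sketch below shows that arithmetic with plain Python floats to stay dependency-free; the function names, the `eps` term, and the sample numbers are assumptions for illustration, not code from this repository.

```python
import math
from typing import List


def discounted_returns(rewards: List[float], gamma: float = 0.99) -> List[float]:
    """Return G_t = r_t + gamma * G_{t+1} for every step of one episode."""
    returns: List[float] = []
    running = 0.0
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def normalized_advantages(returns: List[float], values: List[float],
                          eps: float = 1e-8) -> List[float]:
    """Subtract the baseline, then standardize over the whole episode."""
    advantages = [g - v for g, v in zip(returns, values)]
    mean = sum(advantages) / len(advantages)
    variance = sum((a - mean) ** 2 for a in advantages) / len(advantages)
    std = math.sqrt(variance)
    # eps guards against division by zero when all advantages are equal.
    return [(a - mean) / (std + eps) for a in advantages]


rewards = [1.0, 0.0, 2.0]
values = [2.0, 1.5, 1.0]  # hypothetical baseline estimates V(s_t)
returns = discounted_returns(rewards, gamma=1.0)
advantages = normalized_advantages(returns, values)
print(returns)     # [3.0, 2.0, 2.0]
print(advantages)  # standardized values that sum to roughly zero
```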
-argon2-cffi @ file:///C:/ci/argon2-cffi_1613037959010/work
-async-generator @ file:///home/ktietz/src/ci/async_generator_1611927993394/work
-attrs @ file:///tmp/build/80754af9/attrs_1604765588209/work
-backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
-bleach @ file:///tmp/build/80754af9/bleach_1612211392645/work
+absl-py==0.12.0
+argon2-cffi==20.1.0
+async-generator==1.10
+attrs==20.3.0
+backcall==0.2.0
+bleach==3.3.0
+cachetools==4.2.1
 certifi==2020.12.5
-cffi @ file:///C:/ci/cffi_1613247279197/work
-cloudpickle @ file:///home/conda/feedstock_root/build_artifacts/cloudpickle_1598400192773/work
-colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work
+cffi==1.14.5
+chardet==4.0.0
+cloudpickle==1.6.0
 cycler==0.10.0
-decorator @ file:///home/ktietz/src/ci/decorator_1611930055503/work
-defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
+dataclasses==0.8
+decorator==5.0.6
+defusedxml==0.7.1
 entrypoints==0.3
-future @ file:///D:/bld/future_1610147454374/work
-gym @ file:///D:/bld/gym_1608551988578/work
-importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1602276842396/work
-ipykernel @ file:///C:/ci/ipykernel_1596190155316/work/dist/ipykernel-5.3.4-py3-none-any.whl
-ipython @ file:///C:/ci/ipython_1614616640087/work
-ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
-jedi==0.17.0
-Jinja2 @ file:///tmp/build/80754af9/jinja2_1612213139570/work
-jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
-jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1601311786391/work
-jupyter-core @ file:///C:/ci/jupyter_core_1612213356021/work
-jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
-kiwisolver @ file:///C:/ci/kiwisolver_1612282606037/work
+future==0.18.2
+google-auth==1.29.0
+google-auth-oauthlib==0.4.4
+grpcio==1.37.0
+gym==0.18.0
+idna==2.10
+importlib-metadata==3.10.0
+ipykernel==5.5.3
+ipython==7.16.1
+ipython-genutils==0.2.0
+ipywidgets==7.6.3
+jedi==0.18.0
+Jinja2==2.11.3
+joblib==1.0.1
+jsonschema==3.2.0
+jupyter==1.0.0
+jupyter-client==6.1.12
+jupyter-console==6.4.0
+jupyter-core==4.7.1
+jupyterlab-pygments==0.1.2
+jupyterlab-widgets==1.0.0
+kiwisolver==1.3.1
+Markdown==3.3.4
 MarkupSafe==1.1.1
-matplotlib @ file:///C:/ci/matplotlib-suite_1613408055530/work
+matplotlib==3.3.4
 mistune==0.8.4
-mkl-fft==1.3.0
-mkl-random==1.1.1
-mkl-service==2.3.0
-nbclient @ file:///tmp/build/80754af9/nbclient_1614364831625/work
-nbconvert @ file:///C:/ci/nbconvert_1601914925608/work
-nbformat @ file:///tmp/build/80754af9/nbformat_1610738111109/work
-nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1613680548246/work
-notebook @ file:///C:/ci/notebook_1611348396724/work
-numpy @ file:///C:/ci/numpy_and_numpy_base_1603466732592/work
-olefile==0.46
-packaging @ file:///tmp/build/80754af9/packaging_1611952188834/work
-pandas @ file:///C:/ci/pandas_1614711443373/work
-pandocfilters @ file:///C:/ci/pandocfilters_1605102497129/work
-parso @ file:///tmp/build/80754af9/parso_1607623074025/work
-pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
-Pillow @ file:///C:/ci/pillow_1615224342392/work
-prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1606344362066/work
-prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1602688806899/work
-pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
-pyglet @ file:///D:/bld/pyglet_1612861255678/work
-Pygments @ file:///tmp/build/80754af9/pygments_1615143339740/work
-pyparsing @ file:///home/linux1/recipes/ci/pyparsing_1610983426697/work
-pyrsistent @ file:///C:/ci/pyrsistent_1600141795814/work
-python-dateutil @ file:///home/ktietz/src/ci/python-dateutil_1611928101742/work
-pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work
-pywin32==227
-pywinpty==0.5.7
-pyzmq==20.0.0
-scipy @ file:///C:/ci/scipy_1614023125644/work
-Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work
-sip==4.19.13
-six @ file:///C:/ci/six_1605187374963/work
-terminado==0.9.2
-testpath @ file:///home/ktietz/src/ci/testpath_1611930608132/work
-tornado @ file:///C:/ci/tornado_1606942392901/work
-traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work
-typing-extensions @ file:///home/ktietz/src/ci_mi/typing_extensions_1612808209620/work
-wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
+mpmath==1.2.1
+nbclient==0.5.3
+nbconvert==6.0.7
+nbformat==5.1.3
+nest-asyncio==1.5.1
+nose==1.3.7
+notebook==6.3.0
+numpy==1.19.5
+oauthlib==3.1.0
+packaging==20.9
+pandas==1.1.5
+pandocfilters==1.4.3
+parso==0.8.2
+pexpect==4.8.0
+pickleshare==0.7.5
+Pillow==7.2.0
+prometheus-client==0.10.1
+prompt-toolkit==3.0.18
+protobuf==3.15.8
+ptyprocess==0.7.0
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+pycparser==2.20
+pyglet==1.5.0
+Pygments==2.8.1
+pyparsing==2.4.7
+pyrsistent==0.17.3
+pyslurm==19.5.0.0
+python-dateutil==2.8.1
+pytz==2021.1
+pyzmq==22.0.3
+qtconsole==5.0.3
+QtPy==1.9.0
+requests==2.25.1
+requests-oauthlib==1.3.0
+rsa==4.7.2
+scikit-learn==0.24.1
+scipy==1.5.4
+seaborn==0.11.1
+Send2Trash==1.5.0
+six==1.15.0
+sympy==1.7.1
+tensorboard==2.5.0
+tensorboard-data-server==0.6.0
+tensorboard-plugin-wit==1.8.0
+terminado==0.9.4
+testpath==0.4.4
+threadpoolctl==2.1.0
+torch==1.8.1+cu111
+torchaudio==0.8.1
+torchvision==0.9.1+cu111
+tornado==6.1
+traitlets==4.3.3
+typing-extensions==3.7.4.3
+urllib3==1.26.4
+wcwidth==0.2.5
 webencodings==0.5.1
-wincertstore==0.2
-zipp @ file:///tmp/build/80754af9/zipp_1604001098328/work
+Werkzeug==1.0.1
+widgetsnbextension==3.5.1
+zipp==3.4.1
\ No newline at end of file
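The entries replaced in this file use the `name @ file:///...` form that `pip freeze` typically emits inside conda-managed environments. Those paths refer to build directories on the original machine, so such entries generally cannot be installed elsewhere. A small check for leftover entries of that kind could look like the sketch below; `find_local_path_pins` is a hypothetical helper, not part of this repository.

```python
from typing import List


def find_local_path_pins(requirements_text: str) -> List[str]:
    """Return package names whose requirement is a direct `file://` reference."""
    flagged: List[str] = []
    for raw_line in requirements_text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        name, separator, target = line.partition(" @ ")
        if separator and target.startswith("file://"):
            flagged.append(name)
    return flagged


sample = """\
certifi==2020.12.5
cffi @ file:///C:/ci/cffi_1613247279197/work
gym==0.18.0
"""
print(find_local_path_pins(sample))  # ['cffi']
```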