{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "mixed_logit_model.ipynb", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "display_name": "Python 3 (Spyder)", "language": "python3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "TJHlxbR5kEe-" }, "source": [ "# Mixed Logit" ] }, { "cell_type": "markdown", "metadata": { "id": "qXO6ZtU_F2b4" }, "source": [ "The following examples provide step-by-step instructions to estimate mixed logit models using the `xlogit` package. You can interactively execute the code in this guide by opening it in Google Colab using the following link:\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/arteagac/xlogit/blob/master/examples/mixed_logit_model.ipynb)" ] }, { "cell_type": "markdown", "metadata": { "id": "mra0NiIOFSie" }, "source": [ "## Install and import the `xlogit` package" ] }, { "cell_type": "markdown", "metadata": { "id": "rvLSzJP1GfP1" }, "source": [ "Install `xlogit` using `pip` as shown below. Then import the package and check whether GPU processing is available."
] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "NQbZt7CVh8f_", "outputId": "b823e80f-fd47-4dd1-8656-3fd0d6a1e26a" }, "source": [ "!pip install xlogit\n", "from xlogit import MixedLogit\n", "MixedLogit.check_if_gpu_available()" ], "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Collecting xlogit\n", " Downloading xlogit-0.2.0-py3-none-any.whl (35 kB)\n", "Requirement already satisfied: numpy>=1.13.1 in /usr/local/lib/python3.7/dist-packages (from xlogit) (1.21.6)\n", "Requirement already satisfied: scipy>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from xlogit) (1.4.1)\n", "Installing collected packages: xlogit\n", "Successfully installed xlogit-0.2.0\n", "1 GPU device(s) available. xlogit will use GPU processing\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "True" ] }, "metadata": {}, "execution_count": 1 } ] }, { "cell_type": "markdown", "metadata": { "id": "MP77ezqVfvRI" }, "source": [ "## Swissmetro Dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "BOWB3Lffg5Qc" }, "source": [ "\n", "The Swissmetro dataset contains stated preferences for three alternative transportation modes: car, train, and a newly introduced mode, the Swissmetro. This dataset is commonly used for estimation examples with the `Biogeme` and `PyLogit` packages. The dataset is available at http://transp-or.epfl.ch/data/swissmetro.dat, and [Bierlaire et al. (2001)](https://transp-or.epfl.ch/documents/proceedings/BierAxhaAbay01.pdf) provide a detailed discussion of the data as well as its context and collection process. The explanatory variables in this example include the travel time (`TT`) and cost (`CO`) for each of the three alternative modes."
] }, { "cell_type": "markdown", "metadata": { "id": "n4No84MAeFOM" }, "source": [ "### Read data" ] }, { "cell_type": "markdown", "metadata": { "id": "TEzmVzYDdLS8" }, "source": [ "The dataset is imported into the Python environment using `pandas`. Then, two types of samples are filtered out: those with a trip purpose other than commute or business, and those with an unknown choice. The original dataset contains 10,729 records, but after filtering, 6,768 records remain for the following analysis. Finally, a new column that uniquely identifies each sample is added to the dataframe, and the `CHOICE` column, which originally contains a numerical coding of the choices, is mapped to a description that is consistent with the alternatives in the column names." ] }, { "cell_type": "code", "metadata": { "id": "4jqERhnWhGCc", "colab": { "base_uri": "https://localhost:8080/", "height": 424 }, "outputId": "6bbdca2a-1670-4836-c0d5-d16915ee9597" }, "source": [ "import pandas as pd\n", "import numpy as np\n", "\n", "df_wide = pd.read_table(\"http://transp-or.epfl.ch/data/swissmetro.dat\", sep='\\t')\n", "\n", "# Keep only observations for commute and business purposes that contain known choices\n", "# (.copy() avoids modifying a view of the original dataframe in the assignments below)\n", "df_wide = df_wide[(df_wide['PURPOSE'].isin([1, 3]) & (df_wide['CHOICE'] != 0))].copy()\n", "\n", "df_wide['custom_id'] = np.arange(len(df_wide)) # Add unique identifier\n", "df_wide['CHOICE'] = df_wide['CHOICE'].map({1: 'TRAIN', 2: 'SM', 3: 'CAR'})\n", "df_wide" ], "execution_count": 2, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " GROUP SURVEY SP ID PURPOSE FIRST TICKET WHO LUGGAGE AGE ... \\\n", "0 2 0 1 1 1 0 1 1 0 3 ... \n", "1 2 0 1 1 1 0 1 1 0 3 ... \n", "2 2 0 1 1 1 0 1 1 0 3 ... \n", "3 2 0 1 1 1 0 1 1 0 3 ... \n", "4 2 0 1 1 1 0 1 1 0 3 ... \n", "... ... ... .. ... ... ... ... ... ... ... ... \n", "8446 3 1 1 939 3 1 7 3 1 5 ... \n", "8447 3 1 1 939 3 1 7 3 1 5 ... \n", "8448 3 1 1 939 3 1 7 3 1 5 ... \n", "8449 3 1 1 939 3 1 7 3 1 5 ... 
\n", "8450 3 1 1 939 3 1 7 3 1 5 ... \n", "\n", " TRAIN_CO TRAIN_HE SM_TT SM_CO SM_HE SM_SEATS CAR_TT CAR_CO \\\n", "0 48 120 63 52 20 0 117 65 \n", "1 48 30 60 49 10 0 117 84 \n", "2 48 60 67 58 30 0 117 52 \n", "3 40 30 63 52 20 0 72 52 \n", "4 36 60 63 42 20 0 90 84 \n", "... ... ... ... ... ... ... ... ... \n", "8446 13 30 50 17 30 0 130 64 \n", "8447 12 30 53 16 10 0 80 80 \n", "8448 16 60 50 16 20 0 80 64 \n", "8449 16 30 53 17 30 0 80 104 \n", "8450 13 60 53 21 30 0 100 80 \n", "\n", " CHOICE custom_id \n", "0 SM 0 \n", "1 SM 1 \n", "2 SM 2 \n", "3 SM 3 \n", "4 SM 4 \n", "... ... ... \n", "8446 TRAIN 6763 \n", "8447 TRAIN 6764 \n", "8448 TRAIN 6765 \n", "8449 TRAIN 6766 \n", "8450 TRAIN 6767 \n", "\n", "[6768 rows x 29 columns]" ], "text/html": [ "\n", "
| | GROUP | SURVEY | SP | ID | PURPOSE | FIRST | TICKET | WHO | LUGGAGE | AGE | ... | TRAIN_CO | TRAIN_HE | SM_TT | SM_CO | SM_HE | SM_SEATS | CAR_TT | CAR_CO | CHOICE | custom_id |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2 | 0 | 1 | 1 | 1 | 0 | 1 | 1 | 0 | 3 | ... | 48 | 120 | 63 | 52 | 20 | 0 | 117 | 65 | SM | 0 |
| 1 | 2 | 0 | 1 | 1 | 1 | 0 | 1 | 1 | 0 | 3 | ... | 48 | 30 | 60 | 49 | 10 | 0 | 117 | 84 | SM | 1 |
| 2 | 2 | 0 | 1 | 1 | 1 | 0 | 1 | 1 | 0 | 3 | ... | 48 | 60 | 67 | 58 | 30 | 0 | 117 | 52 | SM | 2 |
| 3 | 2 | 0 | 1 | 1 | 1 | 0 | 1 | 1 | 0 | 3 | ... | 40 | 30 | 63 | 52 | 20 | 0 | 72 | 52 | SM | 3 |
| 4 | 2 | 0 | 1 | 1 | 1 | 0 | 1 | 1 | 0 | 3 | ... | 36 | 60 | 63 | 42 | 20 | 0 | 90 | 84 | SM | 4 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 8446 | 3 | 1 | 1 | 939 | 3 | 1 | 7 | 3 | 1 | 5 | ... | 13 | 30 | 50 | 17 | 30 | 0 | 130 | 64 | TRAIN | 6763 |
| 8447 | 3 | 1 | 1 | 939 | 3 | 1 | 7 | 3 | 1 | 5 | ... | 12 | 30 | 53 | 16 | 10 | 0 | 80 | 80 | TRAIN | 6764 |
| 8448 | 3 | 1 | 1 | 939 | 3 | 1 | 7 | 3 | 1 | 5 | ... | 16 | 60 | 50 | 16 | 20 | 0 | 80 | 64 | TRAIN | 6765 |
| 8449 | 3 | 1 | 1 | 939 | 3 | 1 | 7 | 3 | 1 | 5 | ... | 16 | 30 | 53 | 17 | 30 | 0 | 80 | 104 | TRAIN | 6766 |
| 8450 | 3 | 1 | 1 | 939 | 3 | 1 | 7 | 3 | 1 | 5 | ... | 13 | 60 | 53 | 21 | 30 | 0 | 100 | 80 | TRAIN | 6767 |

6768 rows × 29 columns

| | custom_id | alt | TT | CO | HE | AV | SEATS | GROUP | SURVEY | SP | ... | TICKET | WHO | LUGGAGE | AGE | MALE | INCOME | GA | ORIGIN | DEST | CHOICE |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | TRAIN | 112 | 48 | 120 | 1 | 0 | 2 | 0 | 1 | ... | 1 | 1 | 0 | 3 | 0 | 2 | 0 | 2 | 1 | SM |
| 1 | 0 | SM | 63 | 52 | 20 | 1 | 0 | 2 | 0 | 1 | ... | 1 | 1 | 0 | 3 | 0 | 2 | 0 | 2 | 1 | SM |
| 2 | 0 | CAR | 117 | 65 | 0 | 1 | 0 | 2 | 0 | 1 | ... | 1 | 1 | 0 | 3 | 0 | 2 | 0 | 2 | 1 | SM |
| 3 | 1 | TRAIN | 103 | 48 | 30 | 1 | 0 | 2 | 0 | 1 | ... | 1 | 1 | 0 | 3 | 0 | 2 | 0 | 2 | 1 | SM |
| 4 | 1 | SM | 60 | 49 | 10 | 1 | 0 | 2 | 0 | 1 | ... | 1 | 1 | 0 | 3 | 0 | 2 | 0 | 2 | 1 | SM |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 20299 | 6766 | SM | 53 | 17 | 30 | 1 | 0 | 3 | 1 | 1 | ... | 7 | 3 | 1 | 5 | 1 | 2 | 0 | 1 | 2 | TRAIN |
| 20300 | 6766 | CAR | 80 | 104 | 0 | 1 | 0 | 3 | 1 | 1 | ... | 7 | 3 | 1 | 5 | 1 | 2 | 0 | 1 | 2 | TRAIN |
| 20301 | 6767 | TRAIN | 108 | 13 | 60 | 1 | 0 | 3 | 1 | 1 | ... | 7 | 3 | 1 | 5 | 1 | 2 | 0 | 1 | 2 | TRAIN |
| 20302 | 6767 | SM | 53 | 21 | 30 | 1 | 0 | 3 | 1 | 1 | ... | 7 | 3 | 1 | 5 | 1 | 2 | 0 | 1 | 2 | TRAIN |
| 20303 | 6767 | CAR | 100 | 80 | 0 | 1 | 0 | 3 | 1 | 1 | ... | 7 | 3 | 1 | 5 | 1 | 2 | 0 | 1 | 2 | TRAIN |

20304 rows × 23 columns

| | choice | id | alt | pf | cl | loc | wk | tod | seas | chid |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 1 | 1 | 7 | 5 | 0 | 1 | 0 | 0 | 1 |
| 1 | 0 | 1 | 2 | 9 | 1 | 1 | 0 | 0 | 0 | 1 |
| 2 | 0 | 1 | 3 | 0 | 0 | 0 | 0 | 0 | 1 | 1 |
| 3 | 1 | 1 | 4 | 0 | 5 | 0 | 1 | 1 | 0 | 1 |
| 4 | 0 | 1 | 1 | 7 | 0 | 0 | 1 | 0 | 0 | 2 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 17227 | 0 | 361 | 4 | 0 | 1 | 1 | 0 | 0 | 1 | 4307 |
| 17228 | 1 | 361 | 1 | 9 | 0 | 0 | 1 | 0 | 0 | 4308 |
| 17229 | 0 | 361 | 2 | 7 | 0 | 0 | 0 | 0 | 0 | 4308 |
| 17230 | 0 | 361 | 3 | 0 | 1 | 0 | 1 | 0 | 1 | 4308 |
| 17231 | 0 | 361 | 4 | 0 | 5 | 1 | 0 | 1 | 0 | 4308 |

17232 rows × 10 columns

| | id | alt | choice | income | price | catch |
|---|---|---|---|---|---|---|
| 0 | 1 | beach | 0 | 7083.33170 | 157.930 | 0.0678 |
| 1 | 1 | boat | 0 | 7083.33170 | 157.930 | 0.2601 |
| 2 | 1 | charter | 1 | 7083.33170 | 182.930 | 0.5391 |
| 3 | 1 | pier | 0 | 7083.33170 | 157.930 | 0.0503 |
| 4 | 2 | beach | 0 | 1249.99980 | 15.114 | 0.1049 |
| ... | ... | ... | ... | ... | ... | ... |
| 4723 | 1181 | pier | 0 | 416.66668 | 36.636 | 0.4522 |
| 4724 | 1182 | beach | 0 | 6250.00130 | 339.890 | 0.2537 |
| 4725 | 1182 | boat | 1 | 6250.00130 | 235.436 | 0.6817 |
| 4726 | 1182 | charter | 0 | 6250.00130 | 260.436 | 2.3014 |
| 4727 | 1182 | pier | 0 | 6250.00130 | 339.890 | 0.1498 |

4728 rows × 6 columns

| | person_id | choice_id | alt | choice | price | opcost | range | ev | gas | hybrid | hiperf | medhiperf |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 1 | 1 | 0 | -4.6763 | -47.43 | 0.0 | 0 | 0 | 1 | 0 | 0 |
| 1 | 1 | 1 | 2 | 1 | -5.7209 | -27.43 | 1.3 | 1 | 0 | 0 | 1 | 1 |
| 2 | 1 | 1 | 3 | 0 | -8.7960 | -32.41 | 1.2 | 1 | 0 | 0 | 0 | 1 |
| 3 | 1 | 2 | 1 | 1 | -3.3768 | -4.89 | 1.3 | 1 | 0 | 0 | 1 | 1 |
| 4 | 1 | 2 | 2 | 0 | -9.0336 | -30.19 | 0.0 | 0 | 0 | 1 | 0 | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 4447 | 100 | 1483 | 2 | 0 | -2.8036 | -14.45 | 1.6 | 1 | 0 | 0 | 0 | 0 |
| 4448 | 100 | 1483 | 3 | 0 | -1.9360 | -54.76 | 0.0 | 0 | 1 | 0 | 1 | 1 |
| 4449 | 100 | 1484 | 1 | 1 | -2.4054 | -50.57 | 0.0 | 0 | 1 | 0 | 0 | 0 |
| 4450 | 100 | 1484 | 2 | 0 | -5.2795 | -21.25 | 0.0 | 0 | 0 | 1 | 0 | 1 |
| 4451 | 100 | 1484 | 3 | 0 | -6.0705 | -25.41 | 1.4 | 1 | 0 | 0 | 0 | 0 |

4452 rows × 12 columns
