{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# NCL_eof_1_1.py\nCalculate EOFs of the Sea Level Pressure over the North Atlantic.\n\nThis script illustrates the following concepts:\n  - Calculating EOFs\n  - Drawing a time series plot\n  - Using coordinate subscripting to read a specified geographical region\n  - Rearranging longitude data to span -180 to 180\n  - Calculating symmetric contour intervals\n  - Drawing filled bars above and below a given reference line\n  - Drawing subtitles at the top of a plot\n  - Reordering an array\n\nSee the following URLs for the reproduced NCL plots & script:\n    - Original NCL script: https://www.ncl.ucar.edu/Applications/Scripts/eof_1.ncl\n    - Original NCL plots: https://www.ncl.ucar.edu/Applications/Images/eof_1_1_lg.png and https://www.ncl.ucar.edu/Applications/Images/eof_1_2_lg.png\n\nNote (1):\n    The original NCL plot \"eof_1_2_lg.png\" at the URL above is likely not identical to what the original NCL script generates. When the NCL script is run, it generates a plot with data identical to that plotted by this Python script.\n\nNote (2):\n    This script includes many optional diagnostic print statements that provide information about data slicing. To activate them, set the parameter \"debug\" to True.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Import packages:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import xarray as xr\nimport numpy as np\n\nimport geocat.datafiles as gdf\nimport geocat.viz.util as gvutil\nfrom geocat.viz import cmaps as gvcmaps\nfrom geocat.comp import eofunc, eofunc_ts\n\nimport matplotlib.pyplot as plt\n\nimport cartopy.crs as ccrs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "User-defined parameters and a convenience function:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Specify the region of the globe, the time span, etc.\nlatS = 25.\nlatN = 80.\nlonL = -70.\nlonR = 40.\n\nyearStart = 1979\nyearEnd = 2003\n\nneof = 3  # number of EOFs\n\n# Set to True to activate diagnostic print statements throughout the code\ndebug = False\n\n\n# Convenience function that prints diagnostic messages only when \"debug\" is set to True.\ndef print_debug(message):\n    if debug:\n        print(message)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Read in data:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Open a netCDF data file using the default xarray engine and load the data into an xarray Dataset\nds = xr.open_dataset(gdf.get('netcdf_files/slp.mon.mean.nc'))\n\n# Print the attributes of the SLP variable\nprint_debug('\\n\\nds.slp.attrs:\\n')\nprint_debug(ds.slp.attrs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Flip and sort longitude coordinates:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Shift longitudes into the range [-180, 180) to facilitate data subsetting\n\nprint_debug(\n    f'\\n\\nBefore flip, longitude range is [{ds[\"lon\"].min().data}, {ds[\"lon\"].max().data}].'\n)\n\nds[\"lon\"] = ((ds[\"lon\"] + 180) % 360) - 180\n\n# Sort longitudes, so that subset operations end up being simpler.\nds = ds.sortby(\"lon\")\n\nprint_debug(\n    f'\\n\\nAfter flip, longitude range is [{ds[\"lon\"].min().data}, {ds[\"lon\"].max().data}].'\n)"
      ]
    },
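    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The longitude flip above can be checked on a small, hypothetical grid. The sketch below uses plain NumPy. The 2.5-degree spacing and the names `lon360` and `lon180` are assumptions for illustration, not properties of the dataset.\n\n```python\nimport numpy as np\n\n# Hypothetical longitudes on a 0-360 grid with 2.5-degree spacing\nlon360 = np.arange(0.0, 360.0, 2.5)\n\n# Same formula as in the cell above: shift by 180, wrap into [0, 360), shift back\nlon180 = ((lon360 + 180) % 360) - 180\n\nprint(lon180[:3])   # 0.0, 2.5 and 5.0 are unchanged\nprint(lon180[-3:])  # 352.5, 355.0 and 357.5 become -7.5, -5.0 and -2.5\n\n# The wrapped values jump from 177.5 to -180.0 where the old grid crossed 180 degrees,\n# so they are no longer monotonic; this is why ds.sortby('lon') follows the flip.\nprint(np.all(np.diff(lon180) > 0))  # False\nprint(np.sort(lon180)[[0, -1]])     # first and last values: -180.0 and 177.5\n```\n\nOnce sorted, a slice such as `slice(lonL, lonR)` with `lonL = -70.` and `lonR = 40.` selects one contiguous block that crosses the prime meridian."
      ]
    },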
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Place latitudes in increasing order:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# To facilitate data subsetting\n\nds = ds.sortby(\"lat\", ascending=True)\n\nprint_debug('\\n\\nAfter sorting latitude values, ds[\"lat\"] is:')\nprint_debug(ds[\"lat\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Limit data to the specified years:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "startDate = f'{yearStart}-01-01'\nendDate = f'{yearEnd}-12-01'\n\nds = ds.sel(time=slice(startDate, endDate))\nprint_debug('\\n\\nds:\\n\\n')\nprint_debug(ds)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Utility function:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Define a utility function for computing seasonal means (to mimic NCL's month_to_season())\nfrom pandas.tseries.frequencies import to_offset\n\n\ndef month_to_season(xMon, season):\n    \"\"\" This function takes an xarray dataset containing monthly data spanning years and\n        returns a dataset with one sample per year, for a specified three-month season.\n\n        Time stamps are centered on the season, e.g. season='DJF' returns January timestamps.\n\n        If a calculated season's timestamp falls outside the original range of monthly values, then the calculated mean\n        is dropped.  For example, if the monthly data's time range is [Jan-2000, Dec-2003] and the season is \"DJF\", the\n        seasonal mean computed from the single month of Dec-2003 is dropped.\n    \"\"\"\n    startDate = xMon.time[0]\n    endDate = xMon.time[-1]\n    seasons_pd = {\n        'DJF': ('QS-DEC', 1),\n        'JFM': ('QS-JAN', 2),\n        'FMA': ('QS-FEB', 3),\n        'MAM': ('QS-MAR', 4),\n        'AMJ': ('QS-APR', 5),\n        'MJJ': ('QS-MAY', 6),\n        'JJA': ('QS-JUN', 7),\n        'JAS': ('QS-JUL', 8),\n        'ASO': ('QS-AUG', 9),\n        'SON': ('QS-SEP', 10),\n        'OND': ('QS-OCT', 11),\n        'NDJ': ('QS-NOV', 12)\n    }\n    try:\n        (season_pd, season_sel) = seasons_pd[season]\n    except KeyError:\n        raise ValueError(\"contributed: month_to_season: bad season: SEASON = \" +\n                         season)\n\n    # Compute the three-month means, then move the time labels ahead one month, to the middle month of each season.\n    # (resample()'s `loffset` argument, previously used for this shift, is deprecated in recent xarray versions.)\n    xSeasons = xMon.resample(time=season_pd).mean()\n    xSeasons['time'] = xSeasons.get_index('time') + to_offset('MS')\n\n    # Filter just the desired season, and trim to the desired time range.\n    xSea = xSeasons.sel(time=xSeasons.time.dt.month == season_sel)\n    xSea = xSea.sel(time=slice(startDate, endDate))\n    return xSea"
      ]
    },
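    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what the resampling inside `month_to_season()` does, here is a minimal sketch on a hypothetical monthly series whose values equal their month numbers. It moves the labels with `get_index('time') + to_offset('MS')`, which shifts each season label ahead one month. This is the same shift that a one-month `loffset` produces. The data and variable names are assumptions for illustration.\n\n```python\nimport numpy as np\nimport pandas as pd\nimport xarray as xr\nfrom pandas.tseries.frequencies import to_offset\n\n# Hypothetical monthly data, Jan-2000 to Dec-2001; each value equals its month number\ntime = pd.date_range('2000-01-01', '2001-12-01', freq='MS')\nx = xr.DataArray(np.asarray(time.month, dtype=float), coords={'time': time}, dims='time')\n\n# Three-month bins anchored on December (DJF, MAM, JJA, SON), labelled Dec 1, Mar 1, ...\nseasonal = x.resample(time='QS-DEC').mean()\n\n# Move each label ahead one month, onto the middle month of its season\nseasonal['time'] = seasonal.get_index('time') + to_offset('MS')\n\n# Keep DJF (now labelled January) and trim to the original time range\ndjf = seasonal.sel(time=seasonal.time.dt.month == 1)\ndjf = djf.sel(time=slice(time[0], time[-1]))\nprint(djf.values)\n```\n\nThe first DJF value, 1.5, averages only Jan-2000 and Feb-2000 because Dec-1999 is not in the data. The second, 5.0, averages Dec-2000 (12), Jan-2001 (1) and Feb-2001 (2). The mean built from Dec-2001 alone is labelled Jan-2002 and is trimmed, as the docstring describes."
      ]
    },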
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Compute the desired global seasonal mean using month_to_season():\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Choose the winter season (December-January-February)\nseason = \"DJF\"\nSLP = month_to_season(ds, season)\nprint_debug('\\n\\nSLP:\\n\\n')\nprint_debug(SLP)\n\n# Diagnostic: print the North Atlantic slice of SLP\nsliceSLP = SLP.sel(lat=slice(latS, latN), lon=slice(lonL, lonR))\n\nprint_debug('\\n\\nsliceSLP:\\n')\nprint_debug(sliceSLP)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Create latitude weights, sqrt(cos(lat)) [or sqrt(gw) when Gaussian weights are available]:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "deg2rad = np.pi / 180.\nclat = SLP['lat'].astype(np.float64)\nclat = np.sqrt(np.cos(deg2rad * clat))\nprint_debug('\\n\\nclat:\\n')\nprint_debug(clat)"
      ]
    },
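    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Why the square root? EOFs are eigenvectors of the covariance matrix, which is built from products of pairs of data values. Scaling the data by sqrt(cos(lat)) therefore scales each grid point's variance by cos(lat). On a regular latitude-longitude grid, cos(lat) is approximately proportional to grid-cell area. Here is a quick NumPy check with hypothetical latitudes:\n\n```python\nimport numpy as np\n\n# Hypothetical latitudes in degrees\nlat = np.array([0.0, 30.0, 60.0, 80.0])\nw = np.sqrt(np.cos(np.deg2rad(lat)))\nprint(np.round(w, 3))  # weights decrease toward the pole\n\n# Squaring the weights recovers cos(lat), the factor applied to each point's variance\nprint(np.allclose(w**2, np.cos(np.deg2rad(lat))))  # True\n```\n\n`np.deg2rad(lat)` is equivalent to the `deg2rad * clat` product used in the cell above."
      ]
    },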
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Multiply SLP by weights:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Xarray will apply latitude-based weights to all longitudes and timesteps automatically.\n# This is called \"broadcasting\".\n\n# Copy SLP so that the unweighted data are not overwritten in place\nwSLP = SLP.copy()\nwSLP['slp'] = clat * SLP['slp']\n\n# Metadata for slp must be copied over explicitly; by default, it is not preserved by binary operators like multiplication.\nwSLP['slp'].attrs = dict(ds['slp'].attrs)\nwSLP['slp'].attrs['long_name'] = 'Wgt: ' + wSLP['slp'].attrs['long_name']"
      ]
    },
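    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The broadcasting step can be reproduced on a small, hypothetical all-ones field with dimensions (time, lat, lon). Xarray aligns operands by dimension name, so the 1-D `lat` weights are applied at every time step and longitude. The dimensions of the result are ordered by first appearance across the operands, so placing the weights first moves `lat` to the front.\n\n```python\nimport numpy as np\nimport xarray as xr\n\n# Hypothetical field of ones with dimensions (time, lat, lon)\nfield = xr.DataArray(np.ones((2, 3, 4)), dims=('time', 'lat', 'lon'),\n                     coords={'lat': [30.0, 45.0, 60.0]})\n\n# 1-D weights that vary only along 'lat'\nweights = np.sqrt(np.cos(np.deg2rad(field['lat'])))\n\nweighted = weights * field\nprint(weighted.dims)   # ('lat', 'time', 'lon')\nprint(weighted.shape)  # (3, 2, 4)\n\n# Every (time, lon) position at a given latitude carries that latitude's weight\nprint(np.allclose(weighted.isel(time=1, lon=3), weights))  # True\n```\n\nThis dimension ordering matters when a later function takes an index-based argument such as `time_dim`."
      ]
    },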
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Subset data to the North Atlantic region:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "xw = wSLP.sel(lat=slice(latS, latN), lon=slice(lonL, lonR))\n\nprint_debug('\\n\\nxw:\\n\\n')\nprint_debug(xw.slp)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Compute the EOFs:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
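        "# Weighting (clat * SLP['slp']) placed 'lat' first, so xw['slp'] has dimensions (lat, time, lon);\n# the time dimension is therefore dimension 1.\neof = eofunc(xw[\"slp\"], neof, time_dim=1, meta=True)\n\nprint_debug('\\n\\neof:\\n\\n')\nprint_debug(eof)\n\neof_ts = eofunc_ts(xw[\"slp\"], eof, time_dim=1, meta=True)\n\nprint_debug('\\n\\neof_ts:\\n\\n')\nprint_debug(eof_ts)"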
      ]
    },
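    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`eofunc()` and `eofunc_ts()` hide the linear algebra. The sketch below shows the standard idea using a plain NumPy singular value decomposition on random, hypothetical anomalies. It is not GeoCAT's implementation, whose internals and output conventions may differ. The sign of each EOF is arbitrary, so EOFs from different tools can differ by a factor of -1.\n\n```python\nimport numpy as np\n\nrng = np.random.default_rng(0)\n\n# Hypothetical anomalies: 25 time steps (rows) x 40 grid points (columns)\nntime, nspace = 25, 40\nX = rng.standard_normal((ntime, nspace))\nX = X - X.mean(axis=0)  # remove the time mean at each grid point\n\n# X = U @ diag(S) @ Vt; the rows of Vt are the spatial patterns (EOFs)\nU, S, Vt = np.linalg.svd(X, full_matrices=False)\nn_modes = 3\neofs = Vt[:n_modes]                 # shape (n_modes, nspace)\npcs = U[:, :n_modes] * S[:n_modes]  # time series, shape (ntime, n_modes)\n\n# Percent of the total variance explained by each mode\npcvar = 100 * S**2 / np.sum(S**2)\nprint(eofs.shape, pcs.shape, np.round(pcvar[:n_modes], 1))\n\n# Projecting the anomalies onto the EOFs recovers the time series\nprint(np.allclose(X @ eofs.T, pcs))  # True\n```\n\nThe projection in the last line is conceptually what a function like `eofunc_ts()` computes. `pcvar` plays a role similar to the percent-variance values read from `eof.pcvar` in the plotting cells."
      ]
    },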
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Normalize time series:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Sum spatial weights over the area used.\nnLon = xw.sizes[\"lon\"]\n\n# Bump the upper value of the slice, so that latitude values equal to latN are included.\nclat_subset = clat.sel(lat=slice(latS, latN + 0.01))\nweightTotal = clat_subset.sum() * nLon\neof_ts = eof_ts / weightTotal\n\nprint_debug('\\n\\neof_ts normalized:\\n\\n')\nprint_debug(eof_ts)"
      ]
    },
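    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Label-based slices in xarray include both endpoints, so `slice(latS, latN)` keeps a latitude stored as exactly 80. A value stored as 80 plus a tiny rounding error, however, would be dropped. The hypothetical example below shows the effect of bumping the upper bound. It then forms the total weight the same way as the cell above: the sum over latitudes times the number of longitudes.\n\n```python\nimport numpy as np\nimport xarray as xr\n\n# Hypothetical latitudes; the last one is 80 plus a tiny rounding error\nlat = np.array([75.0, 77.5, 80.0 + 1e-6])\nw = xr.DataArray(np.sqrt(np.cos(np.deg2rad(lat))), coords={'lat': lat}, dims='lat')\n\nprint(w.sel(lat=slice(25.0, 80.0)).sizes['lat'])   # 2: the near-80 latitude is dropped\nprint(w.sel(lat=slice(25.0, 80.01)).sizes['lat'])  # 3: the bumped bound keeps it\n\n# Total weight over the box: each latitude's weight counted once per longitude\nnlon = 45  # hypothetical: -70 to 40 degrees at 2.5-degree spacing\nweight_total = w.sel(lat=slice(25.0, 80.01)).sum() * nlon\nprint(float(weight_total))\n```"
      ]
    },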
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Utility function:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Define a utility function for creating a contour plot.\ndef make_contour_plot(ax, dataset):\n    lat = dataset['lat']\n    lon = dataset['lon']\n    values = dataset.data\n\n    # Use an NCL colormap\n    cmap = gvcmaps.BlWhRe\n\n    # Specify contour levels\n    v = np.linspace(-0.08, 0.08, 9, endpoint=True)\n\n    # The function contourf() produces fill colors; contour() draws invisible lines used only to place contour labels.\n    cplot = ax.contourf(lon,\n                        lat,\n                        values,\n                        levels=v,\n                        cmap=cmap,\n                        extend=\"both\",\n                        transform=ccrs.PlateCarree())\n    p = ax.contour(lon,\n                   lat,\n                   values,\n                   levels=v,\n                   linewidths=0.0,\n                   transform=ccrs.PlateCarree())\n\n    # Label the contours\n    ax.clabel(p, fontsize=8, fmt=\"%0.2f\", colors=\"black\")\n\n    # Add coastlines\n    ax.coastlines(linewidth=0.5)\n\n    # Use geocat.viz.util convenience function to add minor and major tick lines\n    gvutil.add_major_minor_ticks(ax,\n                                 x_minor_per_major=3,\n                                 y_minor_per_major=4,\n                                 labelsize=10)\n\n    # Use geocat.viz.util convenience function to set axes tick values\n    gvutil.set_axes_limits_and_ticks(ax,\n                                     xticks=[-60, -30, 0, 30],\n                                     yticks=[40, 60, 80])\n\n    # Use geocat.viz.util convenience function to make plots look like NCL plots, using latitude & longitude tick labels\n    gvutil.add_lat_lon_ticklabels(ax)\n\n    return cplot, ax"
      ]
    },
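    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The contour levels in `make_contour_plot()` are hard-coded as nine values from -0.08 to 0.08. One way to calculate symmetric contour intervals from the data instead is sketched below. `symmetric_levels` is a hypothetical helper, not a GeoCAT function, and simply spans plus and minus the largest absolute value.\n\n```python\nimport numpy as np\n\n\ndef symmetric_levels(data, nlevels=9):\n    # The largest absolute value (ignoring NaNs) sets a range centered on zero\n    vmax = np.nanmax(np.abs(data))\n    return np.linspace(-vmax, vmax, nlevels)\n\n\nlevels = symmetric_levels(np.array([-0.05, 0.02, 0.07, np.nan]))\nprint(levels[0], levels[-1])  # -0.07 0.07\n```\n\nComputing the levels separately for each EOF panel would give each panel a different scale, which would make the single shared colorbar in the next plot misleading. Computing them once from all EOFs together avoids that."
      ]
    },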
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot (1): Draw a contour plot for each EOF\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Generate figure and axes using a Cartopy projection and set figure size (width, height) in inches\nfig, axs = plt.subplots(neof,\n                        1,\n                        subplot_kw={\"projection\": ccrs.PlateCarree()},\n                        figsize=(6, 10.6))\n\n# Add multiple axes to the figure as contour and contourf plots\nfor i in range(neof):\n    eof_single = eof.sel(evn=i)\n\n    # Create contour plot for the current axes\n    cplot, axs[i] = make_contour_plot(axs[i], eof_single)\n\n    # Use geocat.viz.util convenience function to add titles to left and right of the plot axis.\n    pct = eof.pcvar[i]\n    gvutil.set_titles_and_labels(axs[i],\n                                 lefttitle=f'EOF {i + 1}',\n                                 lefttitlefontsize=10,\n                                 righttitle=f'{pct:.1f}%',\n                                 righttitlefontsize=10)\n\n# Adjust subplot spacings and locations\nplt.subplots_adjust(bottom=0.07, top=0.95, hspace=0.15)\n\n# Add horizontal colorbar\ncbar = plt.colorbar(cplot,\n                    ax=axs,\n                    orientation='horizontal',\n                    shrink=0.9,\n                    pad=0.05,\n                    fraction=.02)\ncbar.ax.tick_params(labelsize=8)\n\n# Set a common title\naxs[0].set_title(f'SLP: DJF: {yearStart}-{yearEnd}', fontsize=14, y=1.12)\n\n# Show the plot\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Utility function:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Define a utility function for creating a bar plot.\ndef make_bar_plot(ax, dataset):\n    years = dataset.time.dt.year.values\n    values = dataset.values\n    # Color negative values blue and non-negative values red\n    colors = ['blue' if val < 0 else 'red' for val in values]\n\n    ax.bar(years, values, color=colors, width=1.0, edgecolor='black', linewidth=0.5)\n    ax.set_ylabel('Pa')\n\n    # Use geocat.viz.util convenience function to add minor and major tick lines\n    gvutil.add_major_minor_ticks(ax,\n                                 x_minor_per_major=4,\n                                 y_minor_per_major=5,\n                                 labelsize=8)\n\n    # Use geocat.viz.util convenience function to set axes tick values\n    gvutil.set_axes_limits_and_ticks(ax,\n                                     xticks=np.linspace(1980, 2000, 6),\n                                     xlim=[1978.5, 2003.5])\n\n    return ax"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot (2): Produce a bar plot for each EOF.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Generate figure and axes and set figure size (width, height) in inches\nfig, axs = plt.subplots(neof, 1, constrained_layout=True, figsize=(6, 7.5))\n\n# Add multiple axes to the figure as bar plots\nfor i in range(neof):\n    eof_single = eof_ts.sel(neval=i)\n\n    axs[i] = make_bar_plot(axs[i], eof_single)\n    pct = eof.pcvar[i]\n    gvutil.set_titles_and_labels(axs[i],\n                                 lefttitle=f'EOF {i + 1}',\n                                 lefttitlefontsize=10,\n                                 righttitle=f'{pct:.1f}%',\n                                 righttitlefontsize=10)\n\n# Set a common title\naxs[0].set_title(f'SLP: DJF: {yearStart}-{yearEnd}', fontsize=14, y=1.12)\n\n# Show the plot\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}