Distributed Machine Learning Pipeline: NDVI ~ Soil + Weather Dynamics

This tutorial walks thru a machine learning pipeline. This example excludes the Extract component in the often referenced ETL (Extract, Transform, Learn) machine learning nomenclature. The overall goal of this analysis is to predict NDVI dynamics from soil and lagged precipitation, temperature, and vapor pressure deficit observations. The brief outline of the tutorial is:

  1. Read and transform the NDVI, Soil, and Weather data.
  2. Merge the three datasets and add 26 weekly lags of precipitation, vpd, and temperature as features.
  3. Shuffle and split data into three groups:
    • 3% for hyperparameter optimization (Group 1)
    • 97 % for final model
      • 77.6% (97% * 80%) for final model training (Group 2)
      • 19.4% (97% * 20%) for final model testing (validation) (Group 3)
  4. Optimize the hyperparamters in an XGBoost model (Xtreme Gradient Boosting) using a small subset of the data.
  5. Using the "best fit" hyperparameters, train the model 77.6% of the data (Group 2).
  6. Validation with the test (hold-out) data (19.4% - Group 3)

Table of Contents

  1. Build a Distributed Cluster
  2. Preprocess, Transform, and Merge the Data
  3. Machine Learning: XGBoost Model
  4. Interpreting the Model
In [1]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import dask_jobqueue as jq
import dask
from dask import dataframe as ddf
from dask import array as da
import os
from dask.distributed import Client, wait
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from tqdm.notebook import tqdm

Build a Distributed Cluster

We will use dask-jobqueue to launch and scale a cluster. For a more detailed example of how this works, please see the other tutorials in the SCINet Geospatial 2020 Workshop. For a quick review, the workflow for defining a cluster and scaling is:

  1. Dask-jobqueue submits jobs to Slurm with an sbatch script
  2. The sbatch scripts define the dask workers with the following info:
    • Partition to launch jobs/workers (partition)
    • X number of processes (i.e. dask workers) per sbatch script (num_processes).
    • Number of threads/cpus per dask worker (num_threads_per_process)
    • Memory allocated per sbatch scipt (mem), which is spread evenly between the dask workers.
  3. Scale the cluster to the total # of workers. Needs to be a multiple of num_processes.

In this example, we are defining one process (dask worker) per sbatch script. Each process will have 40 cpus (an entire node). We then scale the cluster to 9 workers, which total 360 threads at 1.15TB of memory.

In [2]:
partition='short,brief-low'
num_processes = 1
num_threads_per_process = 40
mem = 3.2*num_processes*num_threads_per_process
n_cores_per_job = num_processes*num_threads_per_process
container = '/lustre/project/geospatial_tutorials/wg_2020_ws/data_science_im_rs_vSCINetGeoWS_2020.sif'
env = 'py_geo'
clust = jq.SLURMCluster(queue=partition,
                        processes=num_processes,
                        memory=str(mem)+'GB',
                        cores=n_cores_per_job,
                        interface='ib0',
                        local_directory='$TMPDIR',
                        death_timeout=30,
                        python="singularity exec {} /opt/conda/envs/{}/bin/python".format(container,env),
                        walltime='01:30:00',
                        job_extra=["--output=/dev/null","--error=/dev/null"])
cl=Client(clust)
cl
print('The Dask dashboard address is: /user/'+os.environ['USER']+'/proxy/'+cl.dashboard_link.split(':')[-1].split('/')[0]+'/status')
The Dask dashboard address is: /user/rowan.gaffney/proxy/8787/status

View Cluster Dashboard To view the cluster with the dask dashboard interaface click the dask icon on the left menu pane. Copy and paste the above dashboard address (in the form of /user/{User.Name}/proxy/{port#}/status) into the address bar. Then click on the "Workers", "Progress", "Task Stream", and "CPU" to open those tabs. Drag and arrange in convineint layout on right-hand side of the screen. Note these panes should be mostly blank as we have yet to scale the cluster, which is the next step below.

Dask Icon:

Scale the Cluster to 9 workers (40 cpus per worker). This may take 5-20 seconds to complete.

In [3]:
#scale the cluster
n_workers=9
clust.scale(n=n_workers*num_processes)
#Wait for the cluster to load, show progress bar.
with tqdm(total=n_workers*num_processes) as pbar:
    while (((cl.status == "running") and (len(cl.scheduler_info()["workers"]) < n_workers*num_processes))):
        pbar.update(len(cl.scheduler_info()["workers"])-pbar.n)
    pbar.update(len(cl.scheduler_info()["workers"])-pbar.n)
cl

Out[3]:

Client

Cluster

  • Workers: 9
  • Cores: 360
  • Memory: 1.15 TB
In [4]:
#Lets see the workers are running in SLURM
me = os.environ['USER']
!squeue -u $me
             JOBID PARTITION     NAME     USER ST       TIME  NODES NODELIST(REASON) 
           5001336 brief-low jupyterh rowan.ga  R    1:10:40      1 ceres18-compute-3 
           5001366     short dask-wor rowan.ga  R       0:08      1 ceres14-compute-48 
           5001367     short dask-wor rowan.ga  R       0:08      1 ceres14-compute-49 
           5001368     short dask-wor rowan.ga  R       0:08      1 ceres19-compute-57 
           5001362     short dask-wor rowan.ga  R       0:11      1 ceres18-compute-25 
           5001363     short dask-wor rowan.ga  R       0:11      1 ceres19-compute-62 
           5001364     short dask-wor rowan.ga  R       0:11      1 ceres19-compute-63 
           5001365     short dask-wor rowan.ga  R       0:11      1 ceres14-compute-47 
           5001360     short dask-wor rowan.ga  R       0:13      1 ceres14-compute-55 
           5001361     short dask-wor rowan.ga  R       0:13      1 ceres14-compute-64 

Preprocess, Transform, and Merge the Data

Harmonized Landsat Sentinel Data

Link to data repository: https://hls.gsfc.nasa.gov/

Workflow:

  1. Data is stored in the Zarr format with three dimensions (x,y,time).
  2. Read with xarray.
  3. Divide the data into chunks. Here we have chunked the data by: x=20 pixels, y=20 pixels, date=Entire Dataset
  4. Subset the data to only included "growing season" months.
  5. Convert the xarray object to a 2-Dimensional dataframe.

Notice that the data is not read to memory. The only information stored is the "task graph" and metadata about the final results.

In [5]:
#Read the data with Xarray and rechunk
ndvi = xr.open_zarr('/lustre/project/geospatial_tutorials/wg_2020_ws/data/cper_hls_ndvi.zarr/').chunk({'x':20,'y':20,'date':-1})
ndvi
Out[5]:
<xarray.Dataset>
Dimensions:  (date: 701, x: 321, y: 321)
Coordinates:
  * date     (date) datetime64[ns] 2013-04-19 2013-05-05 ... 2020-06-05
  * x        (x) float64 5.176e+05 5.177e+05 5.177e+05 ... 5.272e+05 5.272e+05
  * y        (y) float64 4.524e+06 4.524e+06 4.524e+06 ... 4.515e+06 4.515e+06
Data variables:
    ndvi     (date, y, x) float64 dask.array<chunksize=(701, 20, 20), meta=np.ndarray>
In [6]:
#Select relevant months and then convert to a dataframe
ndvi_df = ndvi.sel(date=ndvi['date.month'].isin([5,6,7,8,9])).to_dask_dataframe()
#Only include reasonable values (.1 < NDVI < 1.0) in the analysis
ndvi_df = ndvi_df[(ndvi_df.ndvi>.1)&(ndvi_df.ndvi<1.)]
print('There are '+f'{len(ndvi_df):,}'+' NDVI observations.')
ndvi_df
There are 16,233,492 NDVI observations.
Out[6]:
Dask DataFrame Structure:
date x y ndvi
npartitions=281
0 datetime64[ns] float64 float64 float64
103041 ... ... ... ...
... ... ... ... ...
28851480 ... ... ... ...
28954520 ... ... ... ...
Dask Name: getitem, 13451 tasks

Polaris Soil Hydraulic Data

Paper Describing the Data: https://agupubs.onlinelibrary.wiley.com/doi/abs/10.1029/2018WR022797
Data Repository Source: http://hydrology.cee.duke.edu/POLARIS/PROPERTIES/v1.0/

Workflow:

  1. Data is stored in the Zarr format with two dimensions (x,y) and includes 13 variables at 6 depths (78 total). Read with xarray.
  2. Interpolate the data to the same grid as the HLS NDVI data.
  3. Convert the xarray object to a 2-Dimensional Pandas dataframe.
In [7]:
soil = xr.open_zarr('/lustre/project/geospatial_tutorials/wg_2020_ws/data/polaris_soils.zarr/')
#Interpolate to the HLS NDVI grid
soil_df = soil.interp(x=ndvi.x,y=ndvi.y,method='linear').squeeze().to_dataframe().reset_index()
soil_df
Out[7]:
x y alpha_mean_0_5 alpha_mean_100_200 alpha_mean_15_30 alpha_mean_30_60 alpha_mean_5_15 alpha_mean_60_100 band bd_mean_0_5 ... theta_r_mean_15_30 theta_r_mean_30_60 theta_r_mean_5_15 theta_r_mean_60_100 theta_s_mean_0_5 theta_s_mean_100_200 theta_s_mean_15_30 theta_s_mean_30_60 theta_s_mean_5_15 theta_s_mean_60_100
0 517635.0 4524345.0 -0.267512 -0.261857 -0.357152 -0.373476 -0.300172 -0.313353 1 1.358952 ... 0.066429 0.065438 0.058725 0.057867 0.487272 0.457731 0.469949 0.467202 0.477637 0.467717
1 517635.0 4524315.0 -0.353870 -0.397741 -0.435140 -0.464050 -0.388517 -0.429784 1 1.373639 ... 0.066839 0.069663 0.060983 0.061927 0.481719 0.449684 0.469488 0.463120 0.475950 0.459320
2 517635.0 4524285.0 -0.265410 -0.347196 -0.320199 -0.362147 -0.284938 -0.367521 1 1.377104 ... 0.062490 0.063163 0.059096 0.056489 0.480338 0.447308 0.459253 0.451799 0.469832 0.452801
3 517635.0 4524255.0 -0.185915 -0.208319 -0.250943 -0.278967 -0.203988 -0.260081 1 1.362455 ... 0.063173 0.060283 0.057469 0.052539 0.485866 0.453303 0.459874 0.456485 0.471385 0.461014
4 517635.0 4524225.0 -0.194945 -0.154524 -0.241702 -0.248145 -0.206497 -0.205948 1 1.377880 ... 0.059520 0.056813 0.055008 0.046228 0.480045 0.437607 0.461432 0.453434 0.471137 0.448864
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
103036 527235.0 4514865.0 -0.326156 -0.283270 -0.434785 -0.430498 -0.360301 -0.353984 1 1.287102 ... 0.091644 0.085653 0.071946 0.067868 0.514897 0.458693 0.490777 0.481042 0.504160 0.469501
103037 527235.0 4514835.0 -0.303673 -0.289413 -0.416764 -0.410695 -0.340289 -0.342081 1 1.289779 ... 0.091355 0.084758 0.071945 0.067685 0.514902 0.458483 0.487712 0.476795 0.501821 0.465729
103038 527235.0 4514805.0 -0.262427 -0.267768 -0.369375 -0.375675 -0.299791 -0.306792 1 1.303656 ... 0.085868 0.080742 0.069070 0.064427 0.512519 0.454470 0.485322 0.473971 0.499007 0.460852
103039 527235.0 4514775.0 -0.216879 -0.241102 -0.307557 -0.333831 -0.256776 -0.275581 1 1.316483 ... 0.078472 0.076242 0.064507 0.061170 0.509099 0.450149 0.481050 0.468338 0.493489 0.456221
103040 527235.0 4514745.0 -0.165478 -0.195974 -0.250293 -0.271493 -0.202663 -0.227983 1 1.332964 ... 0.073350 0.070925 0.059594 0.056485 0.504206 0.443299 0.474748 0.460958 0.485719 0.449361

103041 rows × 81 columns

PRISM Precipitation, Tempature, and Vapor Pressure Deficit Data

PRISM Data in a CSV file. Note this data was queried at a single point at the center of CPER.

Workflow:

  1. Data is stored in the csv format and includes 7 variables. Read with Pandas using:
    • Skip the 1st 10 rows (PRISM metadata)
    • Convert the time column from a generic object to a date-time object.
  2. Rename the "Date" to "date" to match HLS NDVI data.
  3. Set the "date" column as the index.
  4. Sort the data into descending.
In [8]:
df_env = pd.read_csv('/lustre/project/geospatial_tutorials/wg_2020_ws/data/PRISM_ppt_tmin_tmean_tmax_tdmean_vpdmin_vpdmax_provisional_4km_20120101_20200101_40.8269_-104.7154.csv',
                      skiprows=10,
                      infer_datetime_format=True,
                      parse_dates = ['Date']).rename(columns={'Date':'date'}).set_index('date').sort_index(ascending=False)
df_env
Out[8]:
ppt (mm) tmin (degrees C) tmean (degrees C) tmax (degrees C) tdmean (degrees C) vpdmin (hPa) vpdmax (hPa)
date
2020-01-01 0.00 -11.5 -4.4 2.8 -13.5 0.94 4.83
2019-12-31 0.00 -12.6 -7.5 -2.3 -16.1 0.77 2.85
2019-12-30 0.00 -9.4 -7.4 -5.4 -15.5 1.04 1.77
2019-12-29 1.49 -7.4 -6.7 -6.1 -10.5 0.25 0.96
2019-12-28 5.22 -5.2 -3.2 -1.2 -4.7 0.22 1.06
... ... ... ... ... ... ... ...
2012-01-05 0.00 -7.0 3.8 14.5 -9.9 0.92 14.55
2012-01-04 0.00 -6.5 3.6 13.7 -10.1 0.91 13.65
2012-01-03 0.00 -11.2 -0.9 9.4 -9.8 0.48 8.72
2012-01-02 0.00 -12.1 -5.2 1.8 -12.6 0.34 4.69
2012-01-01 0.00 -11.4 -2.0 7.3 -11.6 0.47 6.23

2923 rows × 7 columns

Transform Function to Merge NDVI, Soil, and PRISM data.

Here we develop a class to merge the three dataset. Note the most import code is in the def transform function.

In [9]:
#Costum transformer in the scikit-learn API syntax
class merge_dsets(BaseEstimator,TransformerMixin):
    def __init__(self, df_soil, df_env,lag):
        self.soil = df_soil
        self.env = df_env
        self.lag = lag
        #self.lag_interv = lag_interval
    def fit(self, X, y=None):
        return self
    def transform(self, X, y=None):
        df = X.copy()
        df = df.merge(self.soil, on =['x','y'])
        df_env_m = pd.DataFrame()
        for i,d in enumerate(df.date.unique()):
            df_env_temp = df_env[df_env.index<d+pd.Timedelta('1days')].resample('1W-'+d.day_name()[0:3].upper(),
                                                                                label='right').agg({'ppt (mm)':'sum',
                                                                                                    'tmean (degrees C)':'mean',
                                                                                                    'vpdmin (hPa)':'mean',
                                                                                                    'vpdmax (hPa)':'mean'}).sort_index(ascending=False).iloc[0:self.lag].reset_index().reset_index().rename(columns={'index':'week'})
            df_env_temp = df_env_temp.drop(columns='date').melt(id_vars='week')
            df_env_temp['col']='week'+df_env_temp.week.astype(str)+'_'+df_env_temp.variable.str.split(' ',expand=True).values[:,0]
            df_env_temp = df_env_temp.set_index('col').drop(columns=['week','variable']).T
            df_env_temp['date']=d
            df_env_temp = df_env_temp.set_index('date',drop=True)
            df_env_m = df_env_m.append(df_env_temp)
        df = df.merge(df_env_m,left_on='date',right_index=True)
        df['DOY'] = df.date.dt.dayofyear
        return(df.drop(columns=['date','x','y','ndvi']),df[['ndvi']])#.to_dask_array(lengths=True))

Machine Learning: XGBoost Model

The "learn" portion in the ETL pipeline.

In [10]:
from sklearn.pipeline import Pipeline
import xgboost as xgb
#from dask_ml.xgboost import XGBRegressor as dask_XGBRegressor
from dask_ml.model_selection import train_test_split
from sklearn.metrics import r2_score
from dask_ml.model_selection import GridSearchCV
from sklearn.model_selection import GridSearchCV as sk_GridSearchCV
import joblib

Hyperparameter Optimization

Shuffle and subset data to a managable size (e.g. will fit in memory when running 360 simaltaneous models). We will use a grid-search, combined with 3-fold cross validation, approach to optimize the relevant hyperparameters (see table below).

Hyperparameter Grid n
n_estimators [150, 250, 300, 350] 4
learning_rate [0.05, 0.1, 0.2, 0.3] 4
max_depth [5, 7, 9, 11] 4
colsample_bytree [.1, .2, .3] 3
gamma [.05, .1, .2] 3

A total of 1728 models (4 4 4 3 3 * 3) will be fit. The hyperparameters assocated with the best scoring model (highest R2) will be used to train the remianing data.

This search can take ~1-2 hour using 360 cores. To run the hyperparameter gridsearch cross validation, set the optimize_hyperparameter variable to True (see two cells below). If you leave as False, we will skip the hyperparameter calculatoins, and just use the hyperparameter values previously calculated.

In [11]:
X_train_hyp, X = train_test_split(ndvi_df,
                                  test_size=0.97,
                                  shuffle=True,
                                  random_state=34)
X_train_hyp,Y_train_hyp = dask.compute(*merge_dsets(df_soil=soil_df,
                                      df_env=df_env,
                                      lag=26).transform(X_train_hyp))
X_train_hyp
Out[11]:
alpha_mean_0_5 alpha_mean_100_200 alpha_mean_15_30 alpha_mean_30_60 alpha_mean_5_15 alpha_mean_60_100 band bd_mean_0_5 bd_mean_100_200 bd_mean_15_30 ... week17_vpdmax week18_vpdmax week19_vpdmax week20_vpdmax week21_vpdmax week22_vpdmax week23_vpdmax week24_vpdmax week25_vpdmax DOY
0 -0.057543 -0.123020 -0.130316 -0.168642 -0.089308 -0.153859 1 1.372200 1.499830 1.438692 ... 4.095714 3.082857 7.395714 6.590000 10.048571 10.514286 13.182857 6.598571 11.681429 125
1 -0.157332 -0.312549 -0.247969 -0.317925 -0.196510 -0.355238 1 1.290691 1.339451 1.407826 ... 4.095714 3.082857 7.395714 6.590000 10.048571 10.514286 13.182857 6.598571 11.681429 125
2 -0.200689 -0.217211 -0.255901 -0.281028 -0.207795 -0.245101 1 1.464260 1.501998 1.495993 ... 4.095714 3.082857 7.395714 6.590000 10.048571 10.514286 13.182857 6.598571 11.681429 125
3 -0.318979 -0.097333 -0.403365 -0.392521 -0.340738 -0.226666 1 1.360848 1.532734 1.378598 ... 4.095714 3.082857 7.395714 6.590000 10.048571 10.514286 13.182857 6.598571 11.681429 125
4 -0.277503 -0.181701 -0.353242 -0.364405 -0.280755 -0.266992 1 1.417268 1.500239 1.425412 ... 4.095714 3.082857 7.395714 6.590000 10.048571 10.514286 13.182857 6.598571 11.681429 125
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
3110 -0.378027 -0.111474 -0.446061 -0.419167 -0.400648 -0.231039 1 1.327470 1.509697 1.341241 ... 42.280000 34.842857 36.871429 36.798571 38.205714 39.044286 32.791429 39.895714 28.351429 157
3111 -0.206608 -0.230582 -0.271564 -0.294051 -0.220352 -0.260066 1 1.451308 1.489187 1.486078 ... 42.280000 34.842857 36.871429 36.798571 38.205714 39.044286 32.791429 39.895714 28.351429 157
3112 -0.256232 -0.139178 -0.312052 -0.324134 -0.269426 -0.222845 1 1.371335 1.492979 1.393018 ... 42.280000 34.842857 36.871429 36.798571 38.205714 39.044286 32.791429 39.895714 28.351429 157
3113 -0.213494 -0.233495 -0.304130 -0.321414 -0.247285 -0.257843 1 1.315532 1.457057 1.377256 ... 42.280000 34.842857 36.871429 36.798571 38.205714 39.044286 32.791429 39.895714 28.351429 157
3114 -0.166889 -0.159117 -0.173296 -0.221773 -0.155919 -0.169908 1 1.339214 1.486455 1.360649 ... 42.280000 34.842857 36.871429 36.798571 38.205714 39.044286 32.791429 39.895714 28.351429 157

486990 rows × 184 columns

In [12]:
# Set to True if you want to run the Gridsearch. This can take >1.5 hrs. Therefore, 
# if set to false, the results (best hyperparameters) hardcoded from a previous run 
# of the model
optimize_hyperparameters = False
In [13]:
if optimize_hyperparameters:
    #Define the grid - space
    param_dist = {'n_estimators': [150,250,300,350],
        'learning_rate': [0.05, 0.1, 0.2, 0.3],
        'max_depth': [5, 7, 9, 11],
        'colsample_bytree': [.1, .2, .3],
        'gamma': [.05, .1, .2]}
    #Define the XGBoost model
    reg = xgb.XGBRegressor(n_jobs=1,verbosity=3)
    #Setup the GridsearchCV function
    gs = GridSearchCV(reg,param_dist,cv=3,scheduler=cl,refit=False,cache_cv=False)
    #Fit all the models
    gs.fit(X_train_hyp.values,Y_train_hyp.values)
    #Get the best fitting parameters
    df_params = pd.DataFrame(gs.cv_results_)
    best_params = df_params[df_params.mean_test_score==df_params.mean_test_score.max()]
    best_params = best_params.params.values[0]
    print(best_params)
else:
    #Best fit parameters from previous run
    best_params = {'colsample_bytree': 0.2,
                   'gamma': 0.1,
                   'learning_rate': 0.05,
                   'max_depth': 7,
                   'n_estimators': 350}
    print('Using the previously calculated parameters, which are:')
    print(best_params)
Using the previously calculated parameters, which are:
{'colsample_bytree': 0.2, 'gamma': 0.1, 'learning_rate': 0.05, 'max_depth': 7, 'n_estimators': 350}

Distributed XGBoost Model

  • Shuffle and split data into "training" (80%) and "testing" (20%). Leave as dask dataframes (data needs to be distributed across all workers), so we will call dask.persist to trigger the calculation (rather than dask.compute).
  • Train XGBoost model using the training data.
  • Model Validation / Accuracy (r2) with "testing" data
In [14]:
# Split the data
X_train, X_test = train_test_split(X,
                                   test_size=0.2,
                                   shuffle=True)
#Merge the weather/soil data and persist the data across the cluster
[X_train,Y_train],[X_test,Y_test] = dask.persist(*[merge_dsets(df_soil=soil_df,df_env=df_env,lag=26).transform(X_train),
                                               merge_dsets(df_soil=soil_df,df_env=df_env,lag=26).transform(X_test)])
wait([X_train,X_test,Y_train,Y_test])
X_train
Out[14]:
Dask DataFrame Structure:
alpha_mean_0_5 alpha_mean_100_200 alpha_mean_15_30 alpha_mean_30_60 alpha_mean_5_15 alpha_mean_60_100 band bd_mean_0_5 bd_mean_100_200 bd_mean_15_30 bd_mean_30_60 bd_mean_5_15 bd_mean_60_100 clay_mean_0_5 clay_mean_100_200 clay_mean_15_30 clay_mean_30_60 clay_mean_5_15 clay_mean_60_100 hb_mean_0_5 hb_mean_100_200 hb_mean_15_30 hb_mean_30_60 hb_mean_5_15 hb_mean_60_100 ksat_mean_0_5 ksat_mean_100_200 ksat_mean_15_30 ksat_mean_30_60 ksat_mean_5_15 ksat_mean_60_100 lambda_mean_0_5 lambda_mean_100_200 lambda_mean_15_30 lambda_mean_30_60 lambda_mean_5_15 lambda_mean_60_100 n_mean_0_5 n_mean_100_200 n_mean_15_30 n_mean_30_60 n_mean_5_15 n_mean_60_100 om_mean_0_5 om_mean_100_200 om_mean_15_30 om_mean_30_60 om_mean_5_15 om_mean_60_100 ph_mean_0_5 ph_mean_100_200 ph_mean_15_30 ph_mean_30_60 ph_mean_5_15 ph_mean_60_100 sand_mean_0_5 sand_mean_100_200 sand_mean_15_30 sand_mean_30_60 sand_mean_5_15 sand_mean_60_100 silt_mean_0_5 silt_mean_100_200 silt_mean_15_30 silt_mean_30_60 silt_mean_5_15 silt_mean_60_100 theta_r_mean_0_5 theta_r_mean_100_200 theta_r_mean_15_30 theta_r_mean_30_60 theta_r_mean_5_15 theta_r_mean_60_100 theta_s_mean_0_5 theta_s_mean_100_200 theta_s_mean_15_30 theta_s_mean_30_60 theta_s_mean_5_15 theta_s_mean_60_100 week0_ppt week1_ppt week2_ppt week3_ppt week4_ppt week5_ppt week6_ppt week7_ppt week8_ppt week9_ppt week10_ppt week11_ppt week12_ppt week13_ppt week14_ppt week15_ppt week16_ppt week17_ppt week18_ppt week19_ppt week20_ppt week21_ppt week22_ppt week23_ppt week24_ppt week25_ppt week0_tmean week1_tmean week2_tmean week3_tmean week4_tmean week5_tmean week6_tmean week7_tmean week8_tmean week9_tmean week10_tmean week11_tmean week12_tmean week13_tmean week14_tmean week15_tmean week16_tmean week17_tmean week18_tmean week19_tmean week20_tmean week21_tmean week22_tmean week23_tmean week24_tmean week25_tmean week0_vpdmin week1_vpdmin week2_vpdmin week3_vpdmin week4_vpdmin week5_vpdmin week6_vpdmin week7_vpdmin week8_vpdmin week9_vpdmin week10_vpdmin week11_vpdmin week12_vpdmin week13_vpdmin week14_vpdmin week15_vpdmin week16_vpdmin week17_vpdmin week18_vpdmin week19_vpdmin week20_vpdmin week21_vpdmin week22_vpdmin week23_vpdmin week24_vpdmin week25_vpdmin week0_vpdmax week1_vpdmax week2_vpdmax week3_vpdmax week4_vpdmax week5_vpdmax week6_vpdmax week7_vpdmax week8_vpdmax week9_vpdmax week10_vpdmax week11_vpdmax week12_vpdmax week13_vpdmax week14_vpdmax week15_vpdmax week16_vpdmax week17_vpdmax week18_vpdmax week19_vpdmax week20_vpdmax week21_vpdmax week22_vpdmax week23_vpdmax week24_vpdmax week25_vpdmax DOY
npartitions=281
float64 float64 float64 float64 float64 float64 int64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 float64 int64
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Dask Name: drop_by_shallow_copy, 281 tasks
In [15]:
#Setup the Distributed XGBoost model and train it on the "training" data
dtrain = xgb.dask.DaskDMatrix(cl, X_train, Y_train)
reg_b = xgb.dask.train(cl,
                       best_params,
                       dtrain,
                       num_boost_round=100,
                       evals=[(dtrain, 'train')])
print(reg_b)
{'booster': <xgboost.core.Booster object at 0x2afba1c8bc90>, 'history': {'train': {'rmse': [0.194975, 0.185721, 0.176977, 0.168692, 0.160853, 0.153405, 0.146336, 0.139718, 0.133419, 0.12742, 0.12183, 0.116479, 0.111427, 0.106723, 0.102239, 0.098028, 0.094032, 0.090281, 0.086761, 0.083473, 0.08037, 0.077431, 0.074695, 0.072133, 0.069717, 0.067483, 0.065383, 0.06343, 0.061609, 0.059914, 0.058332, 0.056865, 0.05551, 0.054234, 0.053065, 0.051984, 0.050988, 0.050069, 0.049222, 0.048449, 0.04773, 0.047075, 0.04646, 0.04589, 0.045383, 0.044889, 0.044433, 0.044048, 0.043694, 0.043365, 0.043058, 0.042778, 0.042514, 0.042262, 0.04203, 0.041833, 0.041631, 0.041468, 0.041305, 0.041138, 0.040998, 0.040875, 0.040727, 0.040621, 0.040498, 0.040396, 0.040308, 0.040233, 0.040159, 0.040072, 0.040013, 0.039958, 0.039908, 0.039853, 0.0398, 0.039744, 0.039705, 0.039641, 0.039592, 0.039544, 0.039507, 0.039466, 0.039414, 0.039377, 0.039351, 0.039327, 0.039292, 0.039265, 0.039231, 0.039199, 0.039169, 0.039137, 0.039108, 0.039091, 0.039058, 0.039045, 0.03903, 0.039012, 0.038986, 0.038964]}}}
In [16]:
#Get the R2 results for the testing data
dtest = xgb.dask.DaskDMatrix(cl, X_test)
pred = xgb.dask.predict(cl, reg_b['booster'], dtest)
reg_r2 = r2_score(Y_test.ndvi.compute().values,pred)
print("The overall R2 is: "+str(reg_r2))
The overall R2 is: 0.7823793516167727
In [17]:
#Big Data Plotting Libraries
import datashader as ds
import holoviews as hv
from holoviews.operation.datashader import datashade, shade, dynspread, rasterize
hv.extension('bokeh')
In [18]:
#Plot the results
Y_plotting = Y_test.compute()
Y_plotting['pred']=pred.compute()
Y_plotting
Out[18]:
ndvi pred
0 0.179668 0.200039
1 0.192408 0.202908
2 0.207891 0.205593
3 0.199034 0.192482
4 0.218768 0.193988
... ... ...
19998 0.260389 0.261929
19999 0.198830 0.260994
20000 0.240795 0.263959
20001 0.261963 0.275211
20002 0.265043 0.266167

3149450 rows × 2 columns

In [19]:
#To plot all the points, we need to rasterize the data (aka a 2d histogram)
pts_res = hv.Points(Y_plotting.values,label="")
rasterize(pts_res).redim.range(Count=(10, 2000)).opts(cmap='viridis',
                                                      tools=['hover'],
                                                      xlim=(0.15,.6),
                                                      ylim=(0.15,.6),
                                                      clipping_colors={'min': 'transparent'},
                                                      xlabel='HLS NDVI',
                                                      ylabel='Predicted NDVI',
                                                      logz=True)
/opt/conda/envs/py_geo/lib/python3.7/site-packages/holoviews/plotting/util.py:685: MatplotlibDeprecationWarning: The global colormaps dictionary is no longer considered public API.
  [cmap for cmap in cm.cmap_d if not
Out[19]:
In [20]:
#Standard approaches can render different results
#Show the top 20 most import features as defined by the XGBoost model
xgb.plot_importance(reg_b['booster'],max_num_features=20,importance_type='weight')
xgb.plot_importance(reg_b['booster'],max_num_features=20,importance_type='gain')
xgb.plot_importance(reg_b['booster'],max_num_features=20,importance_type='cover')
Out[20]:
<AxesSubplot:title={'center':'Feature importance'}, xlabel='F score', ylabel='Features'>
In [21]:
#Import the SHAP libraries
import shap
import matplotlib.pyplot as plt
shap.initjs()
In [22]:
#Split data into better manageable slices
X_shap, _= train_test_split(X_test,test_size=0.95,shuffle=True)

Apply SHAP Model: Below we split the data by month, and examine the effect of the features on the model (by month).

In [23]:
#Day of Year for each month
months = {'May':[121,152],
          'June':[153,182],
          'July':[183,2013],
          'August':[214,244],
          'September':[245,274]}

#Function for calculating SHAP values. We will map this function across the data on the cluster
def calc_shap_vals(block,explainer):
    if len(block)>0:
        block_vals = explainer.shap_values(block)
        return(block_vals)
    else:
        return(np.empty((0,184)))

#Loop over each month and create plot
explainer = shap.TreeExplainer(reg_b['booster'])
for k in months.keys():
    print(k)
    start = months[k][0]
    end = months[k][1]
    #Select only the data in the month
    X_shap1 = X_shap[(X_shap.DOY>=start)&(X_shap.DOY<=end)].repartition(npartitions=9).persist()
    wait(X_shap1)
    #Compute the SHAP values
    shap_vals = X_shap1.to_dask_array(lengths=True).map_blocks(calc_shap_vals,explainer=explainer,dtype='float32').compute()
    #Show the SHAP summary plots for each month
    print('Using an array of size:' +str(shap_vals.shape))
    plt.title(k)
    shap.summary_plot(shap_vals, X_shap1.compute(),max_display=20,title=k)
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
May
Using an array of size:(23440, 184)
June
Using an array of size:(37888, 184)
July
Using an array of size:(95676, 184)
August
Using an array of size:(33453, 184)
September
Using an array of size:(33573, 184)
In [24]:
shap_vals = X_shap.to_dask_array(lengths=True).map_blocks(calc_shap_vals,explainer=explainer,dtype='float32').compute()
shap_vals = shap_vals[~np.isnan(shap_vals).any(axis=1)]
shap.dependence_plot("week0_tmean", shap_vals, X_shap.compute(),interaction_index='DOY')
shap.dependence_plot("week0_ppt", shap_vals, X_shap.compute(),interaction_index='DOY')
shap.dependence_plot("week4_vpdmax", shap_vals, X_shap.compute(),interaction_index='DOY')