diff --git a/src/gluonts/ext/r_forecast/R/univariate_forecast_methods.R b/src/gluonts/ext/r_forecast/R/univariate_forecast_methods.R index a8a9efda09..651a644abb 100644 --- a/src/gluonts/ext/r_forecast/R/univariate_forecast_methods.R +++ b/src/gluonts/ext/r_forecast/R/univariate_forecast_methods.R @@ -48,6 +48,49 @@ arima <- function(ts, params) { fourier.arima.xreg <- function(ts, params, xreg_in, xreg_out){ + + if (missing(xreg_in)){ + fourier.arima(ts, params) + } else { + + fourier.frequency.low.periods <- 4 + fourier.ratio.threshold.low.periods <- 18 + fourier.frequency.high.periods <- 52 + fourier.ratio.threshold.high.periods <- 2 + fourier.order <- 4 + + period <- frequency(ts) + len_ts <- length(ts) + fourier_ratio <- len_ts / period + + fourier <- FALSE + + if ((period > fourier.frequency.low.periods + && fourier_ratio > fourier.ratio.threshold.low.periods) + || (period >= fourier.frequency.high.periods + && fourier_ratio > fourier.ratio.threshold.high.periods)) { + # When the period is high, auto.arima becomes unstable + # per Rob's suggestion, we use Fourier series instead + # cf. https://robjhyndman.com/hyndsight/longseasonality/ + fourier <- TRUE + } + + if (fourier == TRUE) { + K <- min(fourier.order, floor(frequency(ts) / 2)) + seasonal <- FALSE + xreg <- forecast::fourier(ts, K=K) + xreg_in <- as.matrix(xreg_in, xreg) + model <- forecast::auto.arima(ts, seasonal = seasonal, xreg = xreg_in, trace=TRUE) + + xreg <- forecast::fourier(ts, K=K, h=params$prediction_length) + xreg_out <- as.matrix(xreg_out, xreg) + + handleModel(model, params, xreg_out) + } else { + model <- forecast::auto.arima(ts, xreg = xreg_in, trace=TRUE) + handleModel(model, params, xreg_out) + } + fourier.frequency.low.periods <- 4 fourier.ratio.threshold.low.periods <- 18 fourier.frequency.high.periods <- 52 diff --git a/src/gluonts/ext/r_forecast/_univariate_predictor.py b/src/gluonts/ext/r_forecast/_univariate_predictor.py index aaf259e326..8925983adf 100644 --- a/src/gluonts/ext/r_forecast/_univariate_predictor.py +++ b/src/gluonts/ext/r_forecast/_univariate_predictor.py @@ -145,7 +145,6 @@ def _get_r_forecast(self, data: Dict) -> Dict: import rpy2.robjects.numpy2ri rpy2.robjects.numpy2ri.activate() - data["feat_dynamic_real"] = np.transpose(data["feat_dynamic_real"]) nrow, ncol = data["feat_dynamic_real"].shape xreg_in = self._robjects.r.matrix( diff --git a/test/ext/r_forecast/test_r_univariate_predictor.py b/test/ext/r_forecast/test_r_univariate_predictor.py index 4eefa1bbb0..eb5b59115c 100644 --- a/test/ext/r_forecast/test_r_univariate_predictor.py +++ b/test/ext/r_forecast/test_r_univariate_predictor.py @@ -48,10 +48,6 @@ def test_forecasts(method_name): "MLP currently does not work because " "the `neuralnet` package is not yet updated with a known bug fix in ` bips-hb/neuralnet`" ) - if method_name == "fourier.arima.xreg": - pytest.xfail( - "Method `fourier.arima.xreg` does not work because of a known issue." - ) dataset = datasets.get_dataset("constant")