## ----include = FALSE----------------------------------------------------------
# Vignette chunk defaults: show code and its output in one block, with each
# output line prefixed by "#>".
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")

## ----setup--------------------------------------------------------------------
library(elixir)

## ----eval = TRUE--------------------------------------------------------------
# Hand-written inputs for deSolve::ode(): initial state, output times and
# parameter values of a Lotka-Volterra-form system.
y <- c(x = 1, y = 1)
times <- 0:100
parms <- c(
    alpha = 1/6,
    beta = 1/3,
    gamma = 0.25,
    delta = 0.25
)

# Derivative function: given time `t`, state `y` and parameters `parms`,
# return a one-element list holding the rates of change c(dx, dy).
# State and parameter names become variables via with().
func <- function(t, y, parms) {
    with(as.list(c(y, parms)), {
        # x grows at rate alpha and shrinks through the x * y interaction
        dx <- alpha * x - beta * x * y
        # y decays at rate gamma and grows through the x * y interaction
        dy <- -gamma * y + delta * x * y
        list(c(dx, dy))
    })
}

# Run this with:
# sol <- deSolve::ode(y, times, func, parms) 
# matplot(sol[, 1], sol[, -1], type = "l")

## -----------------------------------------------------------------------------
# The same model as the hand-written inputs above, expressed as an unevaluated
# DSL block: an end time, initial values x(0)/y(0), derivative equations
# dx/dt and dy/dt, and parameter assignments. quote() keeps it as a language
# object so that the expr_*() calls below can pattern-match against it.
system <- quote({
    t_end = 100

    x(0) = 1
    y(0) = 1

    dx/dt = alpha * x - beta * x * y
    dy/dt = -gamma * y + delta * x * y

    alpha = 1/6
    beta = 1/3
    gamma = 0.25
    delta = 0.25
})

## -----------------------------------------------------------------------------
# Find the statements in `system` of the form `t_end = <wildcard>`.
# `.X` is a wildcard whose capture is named "X".
expr_match(system, { t_end = .X })

## ----eval = FALSE-------------------------------------------------------------
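# # neither of these will work, because inside a function call `t_end = .X` is
# # parsed as a named argument `t_end`, not as an assignment: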
# expr_match(system, quote(t_end = .X))
# expr_match(system, rlang::expr(t_end = .X))
# 
# # instead you would have to do something like this:
# expr_match(system, quote((t_end = .X))[[2]])
# expr_match(system, rlang::expr((t_end = .X))[[2]])
# 
# # This works because the expression (t_end = .X) is a call, which is list-like
# # with two elements:
# # [[1]] is the symbol `(`, and [[2]] is the call t_end = .X.

## -----------------------------------------------------------------------------
# Return what the `.X` wildcard captured in each match, rather than the
# matched statements themselves.
expr_extract(system, { t_end = .X }, "X")

## -----------------------------------------------------------------------------
# Same extraction with `n = 1` (presumably: return only the first match;
# confirm against the elixir documentation).
expr_extract(system, { t_end = .X }, "X", n = 1)

## -----------------------------------------------------------------------------
# Guard: the model must declare its end time exactly once.
if (expr_count(system, { t_end = .X }) != 1) {
    stop("Need exactly one specification of end time.")
}

## -----------------------------------------------------------------------------
# Build the time grid 0:t_end from the first extracted end time
# (the guard above ensures there is exactly one).
times <- 0:expr_extract(system, { t_end = .X }, "X")[[1]]
times

## -----------------------------------------------------------------------------
# Match the initial-condition statements `X(0) = V`: `.X` captures the state
# variable and `.V` its starting value.
expr_match(system, { .X(0) = .V })

## -----------------------------------------------------------------------------
# The state names (captures of X) and the initial values (captures of V).
expr_extract(system, { .X(0) = .V }, "X")
expr_extract(system, { .X(0) = .V }, "V")

# Combine them into a named numeric initial-state vector, like the
# hand-written `y` at the top.
y <- as.numeric(expr_extract(system, { .X(0) = .V }, "V"))
names(y) <- as.character(expr_extract(system, { .X(0) = .V }, "X"))
y

## -----------------------------------------------------------------------------
# First attempt at finding parameter assignments: `name = value` with a
# single-dot wildcard `.X` on the right-hand side.
expr_match(system, { .P = .X })

## -----------------------------------------------------------------------------
# The same search with the double-dot wildcard `..X` on the right-hand side.
# Presumably `..X` can also capture compound expressions such as `1/6`,
# which `.X` cannot; compare the two outputs.
expr_match(system, { .P = ..X })

## -----------------------------------------------------------------------------
# `|P != "t_end"` attaches a condition to the `.P` wildcard, which excludes
# the end-time statement and leaves only the model parameters.
expr_match(system, { `.P|P != "t_end"` = ..X })

## -----------------------------------------------------------------------------
# Parameter values: capture each right-hand side (the `t_end` statement is
# excluded), evaluate it to a number, and name it after its left-hand side.
# vapply() rather than sapply(): the result is always a numeric vector
# (numeric(0) when there are no parameters), and a value that does not
# evaluate to a single number is an error instead of silently producing a list.
parms <- expr_extract(system, { `.P|P != "t_end"` = ..X }, "X")
parms <- vapply(parms, eval, numeric(1))
names(parms) <- as.character(expr_extract(system, { `.P|P != "t_end"` = ..X }, "P"))

## -----------------------------------------------------------------------------
# Match the derivative equations `dX/dt = <expression>`. The condition
# `substr(A, 1, 1) == "d"` requires the numerator's name to start with "d";
# `:name` presumably restricts `.A` to a symbol (confirm in elixir docs).
expr_match(system, { `.A:name|substr(A, 1, 1) == "d"`/dt = ..X })

## -----------------------------------------------------------------------------
# No capture name is given here, so this presumably returns the whole matched
# equations (they are re-matched against the full pattern in the next step).
statements <- expr_extract(system, { `.A:name|substr(A, 1, 1) == "d"`/dt = ..X })
statements

## -----------------------------------------------------------------------------
# Rewrite each `dX/dt = expr` equation as the R assignment `dX <- expr`,
# reusing the captures A and X in the replacement.
R_statements <- expr_replace(statements,
    { `.A:name|substr(A, 1, 1) == "d"`/dt = ..X },
    { .A <- ..X })
R_statements

## -----------------------------------------------------------------------------
# Reduce each `dX <- expr` to its left-hand side, which gives the names of the
# derivative variables to return.
derivatives <- expr_replace(R_statements, { .D <- ..X }, { .D })
derivatives

## -----------------------------------------------------------------------------
# Generate the derivative function from the model. rlang::expr() builds a
# function definition with the same shape as the hand-written `func` at the
# top: `!!!` splices the generated `dX <- expr` statements into the body and
# the derivative names into c(...). eval() then turns the definition into an
# actual closure.
func <- eval(rlang::expr(
    function(t, y, parms)
    {
        with(as.list(c(y, parms)), {
            !!!R_statements
            return (list(c(!!!derivatives)))
        })
    }
))

## -----------------------------------------------------------------------------
# Turn a quoted ODE model (a `{ ... }` block in the DSL used above) into the
# inputs for deSolve::ode(): a time grid, a named initial state, a named
# parameter vector and a derivative function.
#
# The model is expected to contain:
#   * exactly one `t_end = <number>` statement (times run over 0:t_end),
#   * one `X(0) = <value>` statement per state variable X,
#   * exactly one `dX/dt = <expression>` statement per state variable X,
#   * any number of `<name> = <expression>` parameter statements, each of
#     which evaluates to a single number.
# Stops with an error on a missing/duplicated or non-numeric end time, a
# parameter that does not evaluate to a single number, or a mismatch between
# the state variables and the derivative equations.
# Returns the generated derivative function invisibly.
run_ode <- function(system)
{
    # Get times
    if (expr_count(system, { t_end = .X }) != 1) {
        stop("Need exactly one specification of end time.", call. = FALSE)
    }
    t_end <- expr_extract(system, { t_end = .X }, "X")[[1]]
    if (!is.numeric(t_end) || length(t_end) != 1L || is.na(t_end)) {
        stop("End time must be a single number.", call. = FALSE)
    }
    times <- 0:t_end

    # Get initial state
    y <- as.numeric(expr_extract(system, { .X(0) = .V }, "V"))
    names(y) <- as.character(expr_extract(system, { .X(0) = .V }, "X"))

    # Get parameters. vapply() guarantees a numeric vector (numeric(0) when
    # there are none) and rejects values that are not single numbers.
    parms <- expr_extract(system, { `.P|P != "t_end"` = ..X }, "X")
    parms <- vapply(parms, eval, numeric(1))
    names(parms) <- as.character(expr_extract(system, { `.P|P != "t_end"` = ..X }, "P"))

    # Get statements: rewrite each `dX/dt = expr` as the R assignment `dX <- expr`
    statements <- expr_extract(system, { `.A:name|substr(A, 1, 1) == "d"`/dt = ..X })
    R_statements <- expr_replace(statements,
        { `.A:name|substr(A, 1, 1) == "d"`/dt = ..X },
        { .A <- ..X })

    # deSolve pairs derivatives with state variables by position, not by name.
    # Return them in the order of `y` rather than in the order the equations
    # were written, and require exactly one equation per state variable.
    defined <- as.character(expr_extract(system,
        { `.A:name|substr(A, 1, 1) == "d"`/dt = ..X }, "A"))
    wanted <- sprintf("d%s", names(y))
    if (length(defined) != length(wanted) || !setequal(defined, wanted)) {
        stop("Need exactly one derivative dX/dt for each state variable X.",
            call. = FALSE)
    }
    derivatives <- lapply(wanted, as.name)

    # Splice the statements and the ordered derivative names into a function
    # template, then evaluate the template to get the actual closure.
    func <- eval(rlang::expr(
        function(t, y, parms)
        {
            with(as.list(c(y, parms)), {
                !!!R_statements
                return (list(c(!!!derivatives)))
            })
        }
    ))

    # uncomment if deSolve is available:
    # sol <- deSolve::ode(y, times, func, parms)
    # matplot(sol[, 1], sol[, -1], type = "l")

    invisible(func)
}

# The same model as the `system` defined earlier, repeated so that this final
# example is self-contained, then passed to run_ode().
system <- quote({
    t_end = 100

    x(0) = 1
    y(0) = 1

    dx/dt = alpha * x - beta * x * y
    dy/dt = -gamma * y + delta * x * y

    alpha = 1/6
    beta = 1/3
    gamma = 0.25
    delta = 0.25
})

run_ode(system)

