Skip to content

Modules

Top-level package for jax_russell.

CRRBinomialTree

Bases: BinomialTree

Cox Ross Rubinstein binomial tree.

__call__() is tested against example in Haug.

Source code in jax_russell/trees.py
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
class CRRBinomialTree(BinomialTree):
    """Cox Ross Rubinstein binomial tree.

    `__call__()` is tested against example in Haug.
    """  # noqa

    @partial(jax.jit, static_argnums=0)
    @typeguard.typechecked
    def value(
        self,
        start_price: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
        volatility: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
        time_to_expiration: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
        risk_free_rate: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
        cost_of_carry: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
        is_call: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
        strike: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    ) -> jaxtyping.Float[jaxtyping.Array, "*#contracts"]:
        """Calculate values for option contracts.

        Per-step up/down factors are derived from volatility and time to expiration, then used to
        build the terminal-node probabilities and underlying values. Everything is reshaped by
        `_transform_args_for_discounter()` and handed to `self.discounter`.

        Returns:
            jnp.array: contract values
        """
        up, down = self._calc_factors(volatility, time_to_expiration)
        terminal_probs = self._calc_end_probabilities(up, down, time_to_expiration, cost_of_carry)
        terminal_prices = self._calc_end_values(start_price, up, down)

        discounter_inputs = self._transform_args_for_discounter(
            time_to_expiration,
            risk_free_rate,
            cost_of_carry,
            is_call,
            strike,
            up,
            down,
            terminal_probs,
            terminal_prices,
        )
        return self.discounter(*discounter_inputs)

    def _calc_factors(
        self,
        volatility: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
        time_to_expiration: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    ) -> Tuple[jaxtyping.Float[jaxtyping.Array, "*#contracts"], jaxtyping.Float[jaxtyping.Array, "*#contracts"]]:
        """Calculates the factor by which an asset price is multiplied for upward, downward movement at a step.

        CRR uses `u = exp(sigma * sqrt(dt))` and `d = exp(-sigma * sqrt(dt))` with `dt = T / steps`.

        Returns:
            jnp.array, jnp.array: factors on upward move, factors on downward move
        """
        step_volatility = volatility * jnp.sqrt(time_to_expiration / self.steps)
        up = jnp.exp(step_volatility)
        down = jnp.exp(-step_volatility)
        return up, down

    @partial(jax.jit, static_argnums=0)
    @typeguard.typechecked
    def _calc_end_probabilities(
        self,
        up_factors: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
        down_factors: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
        time_to_expiration: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
        cost_of_carry: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    ) -> jaxtyping.Float[jaxtyping.Array, "*#contracts n"]:  # noqa
        """Calculate the probability of arriving at every end node in the tree.

        Entry `k` of the last axis is `p**k * (1 - p)**(steps - k)`, the probability of one path
        with `k` up moves. For European options it is additionally multiplied by
        `comb(steps, k)`, the number of such paths.

        Returns:
            jnp.Array: Array with probabilities in the last dimension, size `self.steps + 1`
        """
        # Trailing axis lets the per-contract probability broadcast against the node index.
        p_up = jnp.expand_dims(
            self._calc_transition_up_probabilities(
                up_factors,
                down_factors,
                time_to_expiration,
                cost_of_carry,
            ),
            -1,
        )
        n_up_moves = jnp.arange(self.steps + 1)
        probabilities = jnp.power(p_up, n_up_moves) * jnp.power(1 - p_up, self.steps - n_up_moves)
        if self.option_type == "european":
            probabilities = probabilities * comb(self.steps, n_up_moves)
        return probabilities

__call__(*args, **kwargs)

Value arrays of options.

By default, __call__ checks its arguments against value() and passes them through.

Returns:

Type Description

jnp.array: option values

Source code in jax_russell/base.py
65
66
67
68
69
70
71
72
73
74
75
@partial(jax.jit, static_argnums=0)
def __call__(self, *args, **kwargs):
    """Value arrays of options.

    By default, `__call__` checks its arguments against `value()` and passes them through.
    Positional and keyword arrays are broadcast against each other before being forwarded,
    so broadcast-compatible contract arrays may be mixed however they are passed.

    Returns:
        jnp.array: option values
    """
    # Validate the call against `value()`'s signature before doing any array work.
    inspect.signature(self.value).bind(*args, **kwargs)
    keyword_names = list(kwargs)
    # Broadcast positional and keyword arrays together: keyword arrays used to be forwarded
    # untouched and could disagree in shape with the broadcast positional ones.
    broadcast = jnp.broadcast_arrays(*args, *(kwargs[name] for name in keyword_names))
    n_positional = len(args)
    return self.value(
        *broadcast[:n_positional],
        **dict(zip(keyword_names, broadcast[n_positional:])),
    )

__init__(steps, option_type, discounter=None)

Parameters:

Name Type Description Default
steps int

The number of time steps in the binomial tree.

required
Source code in jax_russell/trees.py
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
def __init__(
    self,
    steps: int,
    option_type: str,
    discounter: Union[AmericanDiscounter, EuropeanDiscounter, None] = None,
) -> None:
    """Configure the tree's depth, exercise style and discounter.

    Args:
        steps (int): The number of time steps in the binomial tree.
        option_type (str): Exercise style, either `"european"` or `"american"`.
        discounter: Optional discounter. When omitted, `AmericanDiscounter(steps)` is used for
            American options and `EuropeanDiscounter()` for European ones. If the supplied
            discounter has a non-None `steps` attribute, it must equal `steps`.
    """  # noqa
    assert option_type in [
        "european",
        "american",
    ], f"option_type must be one of `european` or `american` got {option_type}"
    discounter_steps = getattr(discounter, "steps", None)
    assert discounter is None or discounter_steps is None or discounter_steps == steps, (
        f"discounter.steps ({discounter_steps}) must match the tree's steps ({steps})"
    )
    self.steps = steps
    self.option_type = option_type
    if discounter is None:
        discounter = AmericanDiscounter(steps) if option_type == "american" else EuropeanDiscounter()
    self.discounter = discounter

first_order(*args, **kwargs)

Automatically calculate first-order greeks.

Returns:

Type Description

jnp.array: forward-mode Jacobians of the option values with respect to the selected arguments, stacked horizontally

Source code in jax_russell/base.py
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@partial(jax.jit, static_argnums=0)
def first_order(self, *args, **kwargs):
    """Automatically calculate first-order greeks.

    Takes forward-mode Jacobians of `self(*args, **kwargs)` with respect to the positional
    arguments listed in `self.argnums`, or every positional argument when it is None.

    Returns:
        jnp.array: the per-argument Jacobians joined with `jnp.hstack`.
    """
    inspect.signature(self).bind(*args, **kwargs)
    if self.argnums is None:
        wrt = range(len(args))
    else:
        wrt = self.argnums
    jacobians = jax.jacfwd(self, wrt)(*args, **kwargs)
    return jnp.hstack(jacobians)

second_order(*args, **kwargs)

Automatically calculate second-order greeks.

Returns:

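Type Description

jnp.array: forward-mode Jacobians of first_order() with respect to the selected arguments, concatenated along the last axis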

Source code in jax_russell/base.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
@partial(jax.jit, static_argnums=0)
def second_order(self, *args, **kwargs):
    """Automatically calculate second-order greeks.

    Takes forward-mode Jacobians of `self.first_order(*args, **kwargs)` with respect to the
    positional arguments listed in `self.argnums`, or every positional argument when it is None.

    Returns:
        jnp.array: the per-argument Jacobians of `first_order()`, concatenated on `axis=-1`.
    """
    inspect.signature(self).bind(*args, **kwargs)
    wrt = range(len(args)) if self.argnums is None else self.argnums
    jacobian_blocks = jax.jacfwd(self.first_order, wrt)(*args, **kwargs)
    return jnp.concatenate(jacobian_blocks, axis=-1)

solve_implied(expected_option_values, init_params, **kwargs)

Solve for an implied value, usually volatility.

This method allows the flexibility to solve for any combination of values used in the valuation method's __call__() signature. For example, passing {"risk_free_rate": jnp.array([0.05]),"volatility":jnp.array([.5])} will solve for the implied values of both volatility and the risk free rate.

Parameters:

Name Type Description Default
expected_option_values array

option values, typically observed market prices

required
init_params dict[array]

initial guesses to begin solve optimization

required

Returns:

Type Description

params, state: the parameters and state returned by a jaxopt optimizer run()

Source code in jax_russell/base.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def solve_implied(
    self,
    expected_option_values,
    init_params,
    **kwargs,
):
    """Solve for an implied value, usually volatility.

    This method allows the flexibility to solve for any combination of values used in the valuation method's `__call__()` signature.
    For example, passing `{"risk_free_rate": jnp.array([0.05]),"volatility":jnp.array([.5])}` will solve for the implied values of both volatility and the risk free rate.
    All remaining `__call__()` arguments are held fixed and passed by name in `kwargs`.

    Args:
        expected_option_values (jnp.array): option values, typically observed market prices
        init_params (dict[str, jnp.array]): initial guesses to begin solve optimization, keyed by `__call__()` argument name
        **kwargs: the fixed `__call__()` arguments, by name

    Returns:
        params, state: the parameters and state returned by a `jaxopt` optimizer `run()`

    Raises:
        TypeError: if a name appears both in `init_params` and `kwargs` (the fixed value would silently
            override the parameter being solved for), or if the arguments do not fit `__call__()`'s signature.
    """  # noqa: E501
    duplicated = [name for name in init_params if name in kwargs]
    if duplicated:
        raise TypeError(f"arguments {duplicated} were passed both in init_params and as fixed keyword arguments")

    signature = inspect.signature(self.__call__)
    # inspect signature using bind to make sure all args have been passed
    signature.bind(**{**init_params, **kwargs})

    @jax.jit
    def objective(params, expected, kwargs):
        bound_arguments = signature.bind(**{**params, **kwargs})
        residuals = expected - self(*bound_arguments.args, **bound_arguments.kwargs)
        return jnp.mean(residuals**2)

    solver = jaxopt.BFGS(objective)
    return solver.run(
        init_params,
        expected=expected_option_values,
        kwargs=kwargs,
    )

value(start_price, volatility, time_to_expiration, risk_free_rate, cost_of_carry, is_call, strike)

Calculate values for option contracts.

Returns:

Type Description
Float[Array, '*#contracts']

jnp.array: contract values

Source code in jax_russell/trees.py
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
@partial(jax.jit, static_argnums=0)
@typeguard.typechecked
def value(
    self,
    start_price: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    volatility: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    time_to_expiration: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    risk_free_rate: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    cost_of_carry: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    is_call: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    strike: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
) -> jaxtyping.Float[jaxtyping.Array, "*#contracts"]:
    """Calculate values for option contracts.

    Per-step up/down factors are derived from volatility and time to expiration, then used to
    build the terminal-node probabilities and underlying values. Everything is reshaped by
    `_transform_args_for_discounter()` and handed to `self.discounter`.

    Returns:
        jnp.array: contract values
    """
    up, down = self._calc_factors(volatility, time_to_expiration)
    terminal_probs = self._calc_end_probabilities(up, down, time_to_expiration, cost_of_carry)
    terminal_prices = self._calc_end_values(start_price, up, down)

    discounter_inputs = self._transform_args_for_discounter(
        time_to_expiration,
        risk_free_rate,
        cost_of_carry,
        is_call,
        strike,
        up,
        down,
        terminal_probs,
        terminal_prices,
    )
    return self.discounter(*discounter_inputs)

ExerciseValuer

Bases: ABC

Abstract class for Callables that implement, or approximate, the max(exercise value, 0) operation.

This is applied in the intermediate steps of a binomial tree.

Source code in jax_russell/trees.py
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
class ExerciseValuer(abc.ABC):
    """Abstract class for Callables that implement, or approximate, the max(exercise value, 0) operation.

    This is applied in the intermediate steps of a binomial tree. Subclasses only need to provide `adjust()`.
    """

    @typeguard.typechecked
    def __call__(
        self,
        underlying_values: jaxtyping.Float[jaxtyping.Array, "*#contracts n"],
        strike: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
        is_call: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    ) -> jaxtyping.Float[jaxtyping.Array, "*#contracts"]:
        """Calculate or approximate the value of exercising an option.

        The signed difference from `_calc_unadjusted_value()` is passed through `adjust()`.

        Args:
            underlying_values (jaxtyping.Float[jaxtyping.Array, "*#contracts n"]): value of the underlying asset
            strike (jaxtyping.Float[jaxtyping.Array, "*#contracts"]): option strike prices
            is_call (jaxtyping.Float[jaxtyping.Array, "*#contracts"]): whether each option is a call (1.0) or put (0.0)

        Returns:
            jaxtyping.Float[jaxtyping.Array, "*#contracts"]: Exercise values.
        """
        signed_difference = self._calc_unadjusted_value(underlying_values, strike, is_call)
        return self.adjust(signed_difference)

    @typeguard.typechecked
    def _calc_unadjusted_value(
        self,
        underlying_values: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
        strike: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
        is_call: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    ) -> jaxtyping.Float[jaxtyping.Array, "*#contracts"]:
        """Return `underlying - strike` for calls (is_call == 1.0) and `strike - underlying` for puts (0.0)."""
        direction = 2 * is_call - 1  # +1 for calls, -1 for puts
        return (underlying_values - strike) * direction

    @abc.abstractmethod
    def adjust(
        self,
        unadjusted_values: jaxtyping.Float[jax.Array, "*"],
    ) -> jaxtyping.Float[jaxtyping.Array, "*"]:
        """Adjust value difference to calculate an intermediate exercise value.

        This method should transform the difference between strike and underlying, i.e. `underlying - strike` for calls, `strike - underlying` for puts, to an exercise value.
        For example, a standard binomial tree uses max(unadjusted_values, 0.0).
        """  # noqa

__call__(underlying_values, strike, is_call)

Calculate or approximate the value of exercising an option.

Parameters:

Name Type Description Default
underlying_values Float[Array, '#contracts n']

value of the underlying asset

required
strike Float[Array, '*#contracts']

option strike prices

required
is_call Float[Array, '*#contracts']

whether each option is a call (1.0) or put (0.0)

required

Returns:

Type Description
Float[Array, '*#contracts']

jaxtyping.Float[jaxtyping.Array, "*#contracts"]: Exercise values.

Source code in jax_russell/trees.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
@typeguard.typechecked
def __call__(
    self,
    underlying_values: jaxtyping.Float[jaxtyping.Array, "*#contracts n"],
    strike: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    is_call: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
) -> jaxtyping.Float[jaxtyping.Array, "*#contracts"]:
    """Calculate or approximate the value of exercising an option.

    The signed difference from `_calc_unadjusted_value()` is passed through `adjust()`.

    Args:
        underlying_values (jaxtyping.Float[jaxtyping.Array, "*#contracts n"]): value of the underlying asset
        strike (jaxtyping.Float[jaxtyping.Array, "*#contracts"]): option strike prices
        is_call (jaxtyping.Float[jaxtyping.Array, "*#contracts"]): whether each option is a call (1.0) or put (0.0)

    Returns:
        jaxtyping.Float[jaxtyping.Array, "*#contracts"]: Exercise values.
    """
    signed_difference = self._calc_unadjusted_value(underlying_values, strike, is_call)
    return self.adjust(signed_difference)

adjust(unadjusted_values) abstractmethod

Adjust value difference to calculate an intermediate exercise value.

This method should transform the difference between strike and underlying, i.e. underlying - strike for calls, strike - underlying for puts, to an exercise value. For example, a standard binomial tree uses max(unadjusted_values, 0.0).

Source code in jax_russell/trees.py
 96
 97
 98
 99
100
101
102
103
104
105
@abc.abstractmethod
def adjust(
    self,
    unadjusted_values: jaxtyping.Float[jax.Array, "*"],
) -> jaxtyping.Float[jaxtyping.Array, "*"]:
    """Adjust value difference to calculate an intermediate exercise value.

    This method should transform the difference between strike and underlying, i.e. `underlying - strike` for calls, `strike - underlying` for puts, to an exercise value.
    For example, a standard binomial tree uses max(unadjusted_values, 0.0).
    Subclasses must implement this method.

    Args:
        unadjusted_values (jaxtyping.Float[jax.Array, "*"]): signed differences, `underlying - strike` for calls and `strike - underlying` for puts

    Returns:
        jaxtyping.Float[jaxtyping.Array, "*"]: exercise values with the same shape as the input
    """  # noqa

MaxValuer

Bases: ExerciseValuer

Implements the standard maximum operation found in intermediate steps in binomial trees.

Source code in jax_russell/trees.py
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
class MaxValuer(ExerciseValuer):
    """Implements the standard maximum operation found in intermediate steps in binomial trees."""

    def adjust(
        self,
        unadjusted_values: jaxtyping.Float[jax.Array, "*"],
    ) -> jaxtyping.Float[jaxtyping.Array, "*"]:
        """Clip signed strike-underlying differences at zero.

        Args:
            unadjusted_values (jaxtyping.Float[jax.Array, "*"]): `underlying - strike` for calls, `strike - underlying` for puts

        Returns:
            jaxtyping.Float[jaxtyping.Array, "*"]: element-wise max(unadjusted_values, 0.0)
        """  # noqa
        exercise_values = jnp.maximum(unadjusted_values, 0.0)
        return exercise_values

__call__(underlying_values, strike, is_call)

Calculate or approximate the value of exercising an option.

Parameters:

Name Type Description Default
underlying_values Float[Array, '#contracts n']

value of the underlying asset

required
strike Float[Array, '*#contracts']

option strike prices

required
is_call Float[Array, '*#contracts']

whether each option is a call (1.0) or put (0.0)

required

Returns:

Type Description
Float[Array, '*#contracts']

jaxtyping.Float[jaxtyping.Array, "*#contracts"]: Exercise values.

Source code in jax_russell/trees.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
@typeguard.typechecked
def __call__(
    self,
    underlying_values: jaxtyping.Float[jaxtyping.Array, "*#contracts n"],
    strike: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    is_call: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
) -> jaxtyping.Float[jaxtyping.Array, "*#contracts"]:
    """Calculate or approximate the value of exercising an option.

    The signed difference from `_calc_unadjusted_value()` is passed through `adjust()`.

    Args:
        underlying_values (jaxtyping.Float[jaxtyping.Array, "*#contracts n"]): value of the underlying asset
        strike (jaxtyping.Float[jaxtyping.Array, "*#contracts"]): option strike prices
        is_call (jaxtyping.Float[jaxtyping.Array, "*#contracts"]): whether each option is a call (1.0) or put (0.0)

    Returns:
        jaxtyping.Float[jaxtyping.Array, "*#contracts"]: Exercise values.
    """
    signed_difference = self._calc_unadjusted_value(underlying_values, strike, is_call)
    return self.adjust(signed_difference)

adjust(unadjusted_values)

Adjust signed strike-underlying differences by applying the max op.

Parameters:

Name Type Description Default
unadjusted_values Float[Array, '*']

underlying - strike for calls, strike - underlying for puts

required

Returns:

Type Description
Float[Array, '*']

jaxtyping.Float[jaxtyping.Array, "*"]: element-wise max(unadjusted_values, 0.0)

Source code in jax_russell/trees.py
111
112
113
114
115
116
117
118
119
120
121
122
123
def adjust(
    self,
    unadjusted_values: jaxtyping.Float[jax.Array, "*"],
) -> jaxtyping.Float[jaxtyping.Array, "*"]:
    """Clip signed strike-underlying differences at zero.

    Args:
        unadjusted_values (jaxtyping.Float[jax.Array, "*"]): `underlying - strike` for calls, `strike - underlying` for puts

    Returns:
        jaxtyping.Float[jaxtyping.Array, "*"]: element-wise max(unadjusted_values, 0.0)
    """  # noqa
    exercise_values = jnp.maximum(unadjusted_values, 0.0)
    return exercise_values

SoftplusValuer

Bases: ExerciseValuer

Approximate the maximum operation using a softplus function.

This Callable will return log(1 + exp(kx)) / k where k is the sharpness parameter.

Source code in jax_russell/trees.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
class SoftplusValuer(ExerciseValuer):
    """Approximate the maximum operation using a softplus function.

    This Callable will return `log(1 + exp(kx)) / k` where k is the sharpness parameter.
    As k grows, the result approaches `max(x, 0)`.
    """

    def __init__(self, sharpness: float = 1.0) -> None:
        """Store the default sharpness used by `adjust()`.

        Args:
            sharpness (float): sharpness parameter k; must be strictly positive

        Raises:
            ValueError: if `sharpness` is not strictly positive. Zero divides by zero, and a negative
                value turns the softplus into an approximation of `min(x, 0)` instead of `max(x, 0)`.
        """  # noqa
        if not sharpness > 0:
            raise ValueError(f"sharpness must be strictly positive, got {sharpness}")
        super().__init__()
        self.sharpness = sharpness

    def adjust(
        self,
        unadjusted_values: jaxtyping.Float[jax.Array, "*"],
        sharpness: Union[None, float] = None,
    ) -> jaxtyping.Float[jaxtyping.Array, "*"]:
        """Adjust using the softplus function.

        Args:
            unadjusted_values (jaxtyping.Float[jax.Array, "*"]): `underlying - strike` for calls, `strike - underlying` for puts
            sharpness: If None, uses `self.sharpness`

        Returns:
            jaxtyping.Float[jaxtyping.Array, "*"]: element-wise softplus
        """  # noqa
        k = self.sharpness if sharpness is None else sharpness
        # logaddexp(k * x, 0) == log(1 + exp(k * x)) without overflow for large k * x.
        return jnp.logaddexp(k * unadjusted_values, 0.0) / k

__call__(underlying_values, strike, is_call)

Calculate or approximate the value of exercising an option.

Parameters:

Name Type Description Default
underlying_values Float[Array, '#contracts n']

value of the underlying asset

required
strike Float[Array, '*#contracts']

option strike prices

required
is_call Float[Array, '*#contracts']

whether each option is a call (1.0) or put (0.0)

required

Returns:

Type Description
Float[Array, '*#contracts']

jaxtyping.Float[jaxtyping.Array, "*#contracts"]: Exercise values.

Source code in jax_russell/trees.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
@typeguard.typechecked
def __call__(
    self,
    underlying_values: jaxtyping.Float[jaxtyping.Array, "*#contracts n"],
    strike: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    is_call: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
) -> jaxtyping.Float[jaxtyping.Array, "*#contracts"]:
    """Calculate or approximate the value of exercising an option.

    The signed difference from `_calc_unadjusted_value()` is passed through `adjust()`.

    Args:
        underlying_values (jaxtyping.Float[jaxtyping.Array, "*#contracts n"]): value of the underlying asset
        strike (jaxtyping.Float[jaxtyping.Array, "*#contracts"]): option strike prices
        is_call (jaxtyping.Float[jaxtyping.Array, "*#contracts"]): whether each option is a call (1.0) or put (0.0)

    Returns:
        jaxtyping.Float[jaxtyping.Array, "*#contracts"]: Exercise values.
    """
    signed_difference = self._calc_unadjusted_value(underlying_values, strike, is_call)
    return self.adjust(signed_difference)

__init__(sharpness=1.0)

Parameters:

Name Type Description Default
sharpness float

sharpness parameter k

1.0
Source code in jax_russell/trees.py
132
133
134
135
136
137
138
139
def __init__(self, sharpness: float = 1.0) -> None:
    """Store the default sharpness used by `adjust()`.

    Args:
        sharpness (float): sharpness parameter k; must be strictly positive

    Raises:
        ValueError: if `sharpness` is not strictly positive. Zero divides by zero, and a negative
            value turns the softplus into an approximation of `min(x, 0)` instead of `max(x, 0)`.
    """  # noqa
    if not sharpness > 0:
        raise ValueError(f"sharpness must be strictly positive, got {sharpness}")
    super().__init__()
    self.sharpness = sharpness

adjust(unadjusted_values, sharpness=None)

Adjust using the softplus function.

Parameters:

Name Type Description Default
unadjusted_values Float[Array, '*']

underlying - strike for calls, strike - underlying for puts

required
sharpness Union[None, float]

If None, uses self.sharpness

None

Returns:

Type Description
Float[Array, '*']

jaxtyping.Float[jaxtyping.Array, "*"]: element-wise softplus

Source code in jax_russell/trees.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
def adjust(
    self,
    unadjusted_values: jaxtyping.Float[jax.Array, "*"],
    sharpness: Union[None, float] = None,
) -> jaxtyping.Float[jaxtyping.Array, "*"]:
    """Adjust using the softplus function `log(1 + exp(k * x)) / k`.

    Args:
        unadjusted_values (jaxtyping.Float[jax.Array, "*"]): `underlying - strike` for calls, `strike - underlying` for puts
        sharpness: If None, uses `self.sharpness`

    Returns:
        jaxtyping.Float[jaxtyping.Array, "*"]: element-wise softplus
    """  # noqa
    k = self.sharpness if sharpness is None else sharpness
    # logaddexp(k * x, 0) == log(1 + exp(k * x)) without overflow for large k * x.
    return jnp.logaddexp(k * unadjusted_values, 0.0) / k

StockOptionBSM

Bases: StockOptionMixin, GeneralizedBlackScholesMerten

Stock option Black–Scholes–Merton valuation.

Source code in jax_russell/__init__.py
40
41
42
43
class StockOptionBSM(StockOptionMixin, GeneralizedBlackScholesMerten):  # type: ignore[misc]
    """Stock option Black Scholes Merton valuation."""

    # Append the mixin's documentation (if it has any) to this class's docstring.
    __doc__ += StockOptionMixin.__doc__ or ""

StockOptionCRRTree

Bases: StockOptionMixin, CRRBinomialTree

Stock option CRR tree.

Source code in jax_russell/__init__.py
13
14
15
16
class StockOptionCRRTree(StockOptionMixin, CRRBinomialTree):  # type: ignore[misc]
    """Stock option CRR tree."""

    # Append the mixin's documentation (if it has any) to this class's docstring.
    __doc__ += StockOptionMixin.__doc__ or ""

StockOptionContinuousDividendCRRTree

Bases: StockOptionContinuousDividendMixin, CRRBinomialTree

Stock option CRR tree with a continuous dividend.

Source code in jax_russell/__init__.py
19
20
21
22
class StockOptionContinuousDividendCRRTree(StockOptionContinuousDividendMixin, CRRBinomialTree):  # type: ignore[misc]
    """Stock option CRR tree with a continuous dividend."""

    # Append the mixin's documentation (if it has any) to this class's docstring.
    __doc__ += StockOptionContinuousDividendMixin.__doc__ or ""

StockOptionContinuousDividendRBTree

Bases: StockOptionContinuousDividendMixin, RendlemanBartterBinomialTree

Stock option Rendleman Bartter tree with a continuous dividend.

Source code in jax_russell/__init__.py
31
32
33
34
35
36
37
class StockOptionContinuousDividendRBTree(  # type: ignore[misc]
    StockOptionContinuousDividendMixin,
    RendlemanBartterBinomialTree,
):
    """Stock option Rendleman Bartter tree with a continuous dividend."""

    # Append the mixin's documentation (if it has any) to this class's docstring.
    __doc__ += StockOptionContinuousDividendMixin.__doc__ or ""

StockOptionRBTree

Bases: StockOptionMixin, RendlemanBartterBinomialTree

Stock option Rendleman Bartter tree.

Source code in jax_russell/__init__.py
25
26
27
28
class StockOptionRBTree(StockOptionMixin, RendlemanBartterBinomialTree):  # type: ignore[misc]
    """Stock option Rendleman Bartter tree."""

    # Append the mixin's documentation (if it has any) to this class's docstring.
    __doc__ += StockOptionMixin.__doc__ or ""

ValuationModel

Bases: ABC

Abstract class for valuation methods.

Source code in jax_russell/base.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
class ValuationModel(abc.ABC):
    """Abstract class for valuation methods."""

    # Positional argument indices differentiated by `first_order()` / `second_order()`;
    # None means every positional argument.
    argnums = list(range(5))

    @abc.abstractmethod
    @partial(jax.jit, static_argnums=0)
    def value(
        self,
        start_price: jaxtyping.Float[
            jaxtyping.Array,
            "#contracts",
        ],
        volatility: jaxtyping.Float[jaxtyping.Array, "#contracts"],
        time_to_expiration: jaxtyping.Float[jaxtyping.Array, "#contracts"],
        risk_free_rate: jaxtyping.Float[jaxtyping.Array, "#contracts"],
        cost_of_carry: jaxtyping.Float[jaxtyping.Array, "#contracts"],
        is_call: jaxtyping.Float[jaxtyping.Array, "#contracts"],
        strike: jaxtyping.Float[jaxtyping.Array, "#contracts"],
    ) -> jaxtyping.Float[jaxtyping.Array, "#contracts"]:
        """Calculate the value of an option.

        This method is used internally by `__call__()`, and should return the value of options.
        By default, `__call__()` is a pass through to `value()`, but many available mixins overwrite this behavior to pass arguments to `value()`.
        In these cases, this allows the single, general method `value()` to implement valuations, while leveraging `__call__()` for security-specific argument logic and meaningful autodifferentiation.
        """  # noqa

    @partial(jax.jit, static_argnums=0)
    def __call__(self, *args, **kwargs):
        """Value arrays of options.

        By default, `__call__` checks its arguments against `value()` and passes them through.
        Positional and keyword arrays are broadcast against each other before being forwarded,
        so broadcast-compatible contract arrays may be mixed however they are passed.

        Returns:
            jnp.array: option values
        """
        # Validate the call against `value()`'s signature before doing any array work.
        inspect.signature(self.value).bind(*args, **kwargs)
        keyword_names = list(kwargs)
        # Broadcast positional and keyword arrays together: keyword arrays used to be forwarded
        # untouched and could disagree in shape with the broadcast positional ones.
        broadcast = jnp.broadcast_arrays(*args, *(kwargs[name] for name in keyword_names))
        n_positional = len(args)
        return self.value(
            *broadcast[:n_positional],
            **dict(zip(keyword_names, broadcast[n_positional:])),
        )

    @partial(jax.jit, static_argnums=0)
    def first_order(self, *args, **kwargs):
        """Automatically calculate first-order greeks.

        Takes forward-mode Jacobians of `self(*args, **kwargs)` with respect to the positional
        arguments listed in `self.argnums`, or every positional argument when it is None.

        Returns:
            jnp.array: the per-argument Jacobians joined with `jnp.hstack`.
        """
        inspect.signature(self).bind(*args, **kwargs)
        return jnp.hstack(
            jax.jacfwd(
                self,
                range(len(args)) if self.argnums is None else self.argnums,
            )(*args, **kwargs)
        )

    @partial(jax.jit, static_argnums=0)
    def second_order(self, *args, **kwargs):
        """Automatically calculate second-order greeks.

        Takes forward-mode Jacobians of `self.first_order(*args, **kwargs)` with respect to the
        positional arguments listed in `self.argnums`, or every positional argument when it is None.

        Returns:
            jnp.array: the per-argument Jacobians of `first_order()`, concatenated on `axis=-1`.
        """
        inspect.signature(self).bind(*args, **kwargs)
        return jnp.concatenate(
            jax.jacfwd(
                self.first_order,
                range(len(args)) if self.argnums is None else self.argnums,
            )(*args, **kwargs),
            axis=-1,
        )

    def solve_implied(
        self,
        expected_option_values,
        init_params,
        **kwargs,
    ):
        """Solve for an implied value, usually volatility.

        This method allows the flexibility to solve for any combination of values used in the valuation method's `__call__()` signature.
        For example, passing `{"risk_free_rate": jnp.array([0.05]),"volatility":jnp.array([.5])}` will solve for the implied values of both volatility and the risk free rate.
        All remaining `__call__()` arguments are held fixed and passed by name in `kwargs`.

        Args:
            expected_option_values (jnp.array): option values, typically observed market prices
            init_params (dict[str, jnp.array]): initial guesses to begin solve optimization, keyed by `__call__()` argument name
            **kwargs: the fixed `__call__()` arguments, by name

        Returns:
            params, state: the parameters and state returned by a `jaxopt` optimizer `run()`

        Raises:
            TypeError: if a name appears both in `init_params` and `kwargs` (the fixed value would silently
                override the parameter being solved for), or if the arguments do not fit `__call__()`'s signature.
        """  # noqa: E501
        duplicated = [name for name in init_params if name in kwargs]
        if duplicated:
            raise TypeError(f"arguments {duplicated} were passed both in init_params and as fixed keyword arguments")

        signature = inspect.signature(self.__call__)
        # inspect signature using bind to make sure all args have been passed
        signature.bind(**{**init_params, **kwargs})

        @jax.jit
        def objective(params, expected, kwargs):
            bound_arguments = signature.bind(**{**params, **kwargs})
            residuals = expected - self(*bound_arguments.args, **bound_arguments.kwargs)
            return jnp.mean(residuals**2)

        solver = jaxopt.BFGS(objective)
        return solver.run(
            init_params,
            expected=expected_option_values,
            kwargs=kwargs,
        )

__call__(*args, **kwargs)

Value arrays of options.

By default, __call__ checks its arguments against value() and passes them through.

Returns:

Type Description

jnp.array: option values

Source code in jax_russell/base.py
65
66
67
68
69
70
71
72
73
74
75
@partial(jax.jit, static_argnums=0)
def __call__(self, *args, **kwargs):
    """Value arrays of options.

    By default, `__call__` checks its arguments against `value()` and passes them through.
    Positional and keyword arrays are broadcast against each other before being forwarded,
    so broadcast-compatible contract arrays may be mixed however they are passed.

    Returns:
        jnp.array: option values
    """
    # Validate the call against `value()`'s signature before doing any array work.
    inspect.signature(self.value).bind(*args, **kwargs)
    keyword_names = list(kwargs)
    # Broadcast positional and keyword arrays together: keyword arrays used to be forwarded
    # untouched and could disagree in shape with the broadcast positional ones.
    broadcast = jnp.broadcast_arrays(*args, *(kwargs[name] for name in keyword_names))
    n_positional = len(args)
    return self.value(
        *broadcast[:n_positional],
        **dict(zip(keyword_names, broadcast[n_positional:])),
    )

first_order(*args, **kwargs)

Automatically calculate first-order greeks.

Returns:

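Type Description
jnp.array Forward-mode Jacobians (jax.jacfwd) of the option values with respect to the selected positional arguments (self.argnums, or all positional arguments when it is None), stacked with jnp.hstack.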

Source code in jax_russell/base.py
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@partial(jax.jit, static_argnums=0)
def first_order(self, *args, **kwargs):
    """Automatically calculate first-order greeks.

    Differentiates `self(*args, **kwargs)` with forward-mode autodiff
    (`jax.jacfwd`) with respect to the positional arguments selected by
    `self.argnums`, or with respect to every positional argument when
    `self.argnums` is None. Keyword arguments are passed through but are not
    differentiated.

    Returns:
        jnp.array: the per-argument Jacobians, stacked with `jnp.hstack`.
    """
    # Fail fast (TypeError) if the arguments do not fit `__call__`.
    inspect.signature(self).bind(*args, **kwargs)
    argnums = range(len(args)) if self.argnums is None else self.argnums
    jacobians = jax.jacfwd(self, argnums)(*args, **kwargs)
    return jnp.hstack(jacobians)

second_order(*args, **kwargs)

Automatically calculate second-order greeks.

Returns:

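Type Description
jnp.array Forward-mode Jacobians (jax.jacfwd) of first_order() with respect to the selected positional arguments (self.argnums, or all positional arguments when it is None), concatenated along the last axis.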

Source code in jax_russell/base.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
@partial(jax.jit, static_argnums=0)
def second_order(self, *args, **kwargs):
    """Automatically calculate second-order greeks.

    Differentiates `self.first_order(*args, **kwargs)` once more with
    forward-mode autodiff (`jax.jacfwd`), with respect to the positional
    arguments selected by `self.argnums`, or every positional argument when
    `self.argnums` is None. Keyword arguments are passed through but are not
    differentiated.

    Returns:
        jnp.array: the per-argument Jacobians of `first_order()`,
            concatenated along the last axis.
    """
    # Fail fast (TypeError) if the arguments do not fit `__call__`.
    inspect.signature(self).bind(*args, **kwargs)
    argnums = range(len(args)) if self.argnums is None else self.argnums
    jacobians = jax.jacfwd(self.first_order, argnums)(*args, **kwargs)
    return jnp.concatenate(jacobians, axis=-1)

solve_implied(expected_option_values, init_params, **kwargs)

Solve for an implied value, usually volatility.

This method allows the flexibility to solve for any combination of values used in the valuation method's __call__() signature. For example, passing {"risk_free_rate": jnp.array([0.05]),"volatility":jnp.array([.5])} will solve for the implied values of both volatility and the risk free rate.

Parameters:

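| Name | Type | Description | Default |
| --- | --- | --- | --- |
| expected_option_values | jnp.array | option values, typically observed market prices | required |
| init_params | dict[str, jnp.array] | initial guesses to begin solve optimization, keyed by `__call__()` argument name | required |
| **kwargs | | remaining `__call__()` arguments, held fixed during the solve | |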

Returns:

Type Description

params, state: the parameters and state returned by a jaxopt optimizer run()

Source code in jax_russell/base.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def solve_implied(
    self,
    expected_option_values,
    init_params,
    **kwargs,
):
    """Solve for an implied value, usually volatility.

    This method allows the flexibility to solve for any combination of values used in the valuation method's `__call__()` signature.
    For example, passing `{"risk_free_rate": jnp.array([0.05]),"volatility":jnp.array([.5])}` will solve for the implied values of both volatility and the risk free rate.

    Args:
        expected_option_values (jnp.array): option values, typically observed market prices
        init_params (dict[str, jnp.array]): initial guesses to begin solve optimization, keyed by `__call__()` argument name
        **kwargs: remaining `__call__()` arguments, held fixed during the solve

    Returns:
        params, state: the parameters and state returned by a `jaxopt` optimizer `run()`

    Raises:
        TypeError: if a name is passed both in `init_params` and as a keyword argument,
            or if the combined arguments do not bind to the `__call__()` signature.
    """  # noqa: E501
    # A name given both as an initial guess and as a fixed value would be
    # overridden by the fixed value in `{**params, **kwargs}` below, so the
    # objective would not depend on it and the solve would silently do nothing.
    overlapping = sorted(set(init_params) & set(kwargs))
    if overlapping:
        raise TypeError(
            f"arguments passed both as initial guesses and fixed values: {overlapping}"
        )

    signature = inspect.signature(self.__call__)
    # inspect signature using bind to make sure all args have been passed
    signature.bind(**{**init_params, **kwargs})

    @jax.jit
    def objective(params, expected, kwargs):
        # Mean squared error between observed and model option values.
        bound_arguments = signature.bind(**{**params, **kwargs})
        residuals = expected - self(*bound_arguments.args, **bound_arguments.kwargs)
        return jnp.mean(residuals**2)

    solver = jaxopt.BFGS(
        objective,
    )
    res = solver.run(
        init_params,
        expected=expected_option_values,
        kwargs=kwargs,
    )
    return res

value(start_price, volatility, time_to_expiration, risk_free_rate, cost_of_carry, is_call, strike) abstractmethod

Calculate the value of an option.

This method is used internally by __call__(), and should return the value of options. By default, __call__() is a pass through to value(), but many available mixins overwrite this behavior to pass arguments to value(). In these cases, this allows the single, general method value() to implement valuations, while leveraging __call__() for security-specific argument logic and meaningful autodifferentiation.

Source code in jax_russell/base.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
@abc.abstractmethod
@partial(jax.jit, static_argnums=0)
def value(
    self,
    start_price: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    volatility: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    time_to_expiration: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    risk_free_rate: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    cost_of_carry: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    is_call: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
    strike: jaxtyping.Float[jaxtyping.Array, "*#contracts"],
) -> jaxtyping.Float[jaxtyping.Array, "*#contracts"]:
    """Calculate the value of an option.

    This method is used internally by `__call__()`, and should return the value of options.
    By default, `__call__()` is a pass through to `value()`, but many available mixins overwrite this behavior to pass arguments to `value()`.
    In these cases, this allows the single, general method `value()` to implement valuations, while leveraging `__call__()` for security-specific argument logic and meaningful autodifferentiation.

    The default `__call__()` broadcasts positional arguments to a common shape of any rank
    with `jnp.broadcast_arrays`. Inputs are therefore annotated with the variadic, broadcastable
    `*#contracts` shape, matching concrete implementations such as `CRRBinomialTree.value`.

    Args:
        start_price: starting price of the underlying
        volatility: volatility of the underlying
        time_to_expiration: time remaining until expiration
        risk_free_rate: risk-free interest rate
        cost_of_carry: cost of carry
        is_call: call/put indicator as a float array (presumably 1.0 for calls, 0.0 for puts; confirm against implementations)
        strike: strike price

    Returns:
        jnp.array: option values
    """  # noqa