package arrayjit

  1. Overview
  2. Docs
module Lazy = Utils.Lazy

The code for operating on n-dimensional arrays.

module Nd = Ndarray
module Tn = Tnode
module Debug_runtime = Utils.Debug_runtime
module Scope_id : sig ... end
type scope_id = Scope_id.t = {
  1. nd : Tn.t;
  2. scope_id : Base.int;
}
val sexp_of_scope_id : scope_id -> Sexplib0.Sexp.t
val equal_scope_id : scope_id -> scope_id -> Base.bool
val hash_fold_scope_id : Ppx_hash_lib.Std.Hash.state -> scope_id -> Ppx_hash_lib.Std.Hash.state
val hash_scope_id : scope_id -> Ppx_hash_lib.Std.Hash.hash_value
val compare_scope_id : scope_id -> scope_id -> Base.int

*** Low-level representation ***

val get_scope : Tn.t -> scope_id
type t =
  1. | Noop
  2. | Comment of Base.string
  3. | Staged_compilation of Base.unit -> Base.unit
  4. | Seq of t * t
  5. | For_loop of {
    1. index : Indexing.symbol;
    2. from_ : Base.int;
    3. to_ : Base.int;
    4. body : t;
    5. trace_it : Base.bool;
    }
  6. | Zero_out of Tn.t
  7. | Set of {
    1. array : Tn.t;
    2. idcs : Indexing.axis_index Base.array;
    3. llv : float_t;
    4. mutable debug : Base.string;
    }
  8. | Set_local of scope_id * float_t

Cases: `t` — code (statements); `float_t` — an expression evaluating to a single number at some precision.

and float_t =
  1. | Local_scope of {
    1. id : scope_id;
    2. prec : Ops.prec;
    3. body : t;
    4. orig_indices : Indexing.axis_index Base.array;
    }
  2. | Get_local of scope_id
  3. | Get_global of Ops.global_identifier * Indexing.axis_index Base.array Base.option
  4. | Get of Tn.t * Indexing.axis_index Base.array
  5. | Binop of Ops.binop * float_t * float_t
  6. | Unop of Ops.unop * float_t
  7. | Constant of Base.float
  8. | Embed_index of Indexing.axis_index
val sexp_of_t : t -> Sexplib0.Sexp.t
val sexp_of_float_t : float_t -> Sexplib0.Sexp.t
val equal : t -> t -> Base.bool
val equal_float_t : float_t -> float_t -> Base.bool
val compare : t -> t -> Base.int
val compare_float_t : float_t -> float_t -> Base.int
val binop : op:Ops.binop -> rhs1:float_t -> rhs2:float_t -> float_t
val unop : op:Ops.unop -> rhs:float_t -> float_t
val flat_lines : t Base.List.t -> t Base.List.t
val unflat_lines : t list -> t
val comment_to_name : string -> string
val extract_block_name : t Base.List.t -> string

*** Optimization ***

type virtualize_settings = {
  1. mutable enable_device_only : Base.bool;
  2. mutable max_visits : Base.int;
  3. mutable max_tracing_dim : Base.int;
  4. mutable inline_scalar_constexprs : Base.bool;
}
val virtualize_settings : virtualize_settings
type visits =
  1. | Visits of Base.int
  2. | Recurrent
    (*

    A visit is Recurrent when the node is accessed before any assignment to it within an update.

    *)
val visits_of_sexp : Sexplib0.Sexp.t -> visits
val sexp_of_visits : visits -> Sexplib0.Sexp.t
val equal_visits : visits -> visits -> Base.bool
val visits : Base.int -> visits
val recurrent : visits
val is_visits : visits -> bool
val is_recurrent : visits -> bool
val visits_val : visits -> Base.int Stdlib.Option.t
val recurrent_val : visits -> unit Stdlib.Option.t
module Variants_of_visits : sig ... end
type traced_array = {
  1. nd : Tn.t;
  2. mutable computations : (Indexing.axis_index Base.array Base.option * t) Base.list;
    (*

    The computations (of the tensor node) are retrieved for optimization as soon as they are populated, so that the inlined code corresponds precisely to the changes to the arrays that would happen up to that point. Within the code blocks paired with an index tuple, all assignments and accesses must happen via the index tuple; if this is not the case for some assignment, the node cannot be virtual. Currently, we only allow for-loop symbols in assignment indices of virtual nodes.

    *)
  3. assignments : Base.int Base.array Base.Hash_set.t;
  4. accesses : (Base.int Base.array, visits) Base.Hashtbl.t;
    (*

    For dynamic indices, we use the index value 0. This leads to an overestimate of visits, which is safe.

    *)
  5. mutable zero_initialized : Base.bool;
  6. mutable zeroed_out : Base.bool;
  7. mutable read_before_write : Base.bool;
    (*

    The node is read before it is written (i.e. it is recurrent).

    *)
  8. mutable read_only : Base.bool;
  9. mutable is_scalar_constexpr : Base.bool;
    (*

    True only if the tensor node has all axes of dimension 1, is either zeroed-out or assigned before being accessed, is assigned at most once, and is assigned from an expression involving only constants or tensor nodes that were themselves is_scalar_constexpr at that time.

    *)
}
val sexp_of_traced_array : traced_array -> Sexplib0.Sexp.t
val get_node : (Tn.t, traced_array) Base.Hashtbl.t -> Tn.t Base.Hashtbl.key -> traced_array
val partition_tf_with_comment : t Base.Array.t -> f:(t -> bool) -> t Base.Array.t * t Base.Array.t
val visit : is_assigned:Base.bool -> visits option -> visits
val is_constexpr_comp : (Tn.t, traced_array) Base.Hashtbl.t -> float_t -> Base.bool
val is_scalar_dims : Tn.t -> bool
val visit_llc : (Tn.t, traced_array) Base.Hashtbl.t -> (Indexing.symbol, Tn.t) Base.Hashtbl.t -> max_visits:int -> t -> unit
val check_and_store_virtual : traced_array -> Indexing.static_symbol Base.List.t -> t -> unit
val inline_computation : id:scope_id -> traced_array -> Indexing.static_symbol Base.List.t -> Indexing.axis_index Base.Array.t -> t option
val optimize_integer_pow : bool Base.ref
val unroll_pow : base:float_t -> exp:Base.int -> float_t
val virtual_llc : (Tn.t, traced_array) Base.Hashtbl.t -> (Indexing.symbol, Tn.t Base.Hashtbl.key) Base.Hashtbl.t -> Indexing.static_symbol Base.List.t -> t -> t
val cleanup_virtual_llc : (Indexing.symbol, Tn.t) Base.Hashtbl.t -> static_indices:Indexing.static_symbol Base.List.t -> t -> t
val substitute_float : var:float_t -> value:float_t -> float_t -> float_t
val substitute_proc : var:float_t -> value:float_t -> t -> t
val simplify_llc : t -> t
type traced_store = (Tn.t, traced_array) Base.Hashtbl.t
val sexp_of_traced_store : traced_store -> Sexplib0.Sexp.t
type optimized = traced_store * t
val optimize_proc : Indexing.static_symbol Base.List.t -> t -> (Tnode.t, traced_array) Base.Hashtbl.t * t
val code_hum_margin : int Base.ref
val pp_comma : Stdlib.Format.formatter -> unit -> unit
val pp_symbol : Stdlib.Format.formatter -> Indexing.symbol -> unit
val pp_static_symbol : Stdlib.Format.formatter -> Indexing.static_symbol -> unit
val pp_index : Stdlib.Format.formatter -> Indexing.axis_index -> unit
val pp_indices : Stdlib.Format.formatter -> Indexing.axis_index Base.Array.t -> unit
val fprint_function_header : ?name:string -> ?static_indices:Indexing.static_symbol list -> unit -> Stdlib.Format.formatter -> unit
val fprint_hum : ?ident_style:[< `Heuristic_ocannl | `Name_and_label | `Name_only > `Heuristic_ocannl ] -> ?name:string -> ?static_indices:Indexing.static_symbol list -> unit -> Stdlib.Format.formatter -> t -> unit
val compile_proc : unoptim_ll_source:Stdlib.Format.formatter Base.Option.t -> ll_source:Stdlib.Format.formatter Base.Option.t -> name:Base.string -> Indexing.static_symbol Base.list -> t -> (Tn.t, traced_array) Base.Hashtbl.t * t
val loop_over_dims : Base__Int.t Base.Array.t -> body:(Indexing.axis_index Base.Array.t -> t) -> t
OCaml

Innovation. Community. Security.