|
|
from __future__ import annotations
|
|
|
|
|
|
from torchgen.api import dispatcher
|
|
|
from torchgen.api.types import (
|
|
|
BaseCppType,
|
|
|
BaseCType,
|
|
|
Binding,
|
|
|
boolT,
|
|
|
ConstRefCType,
|
|
|
CType,
|
|
|
longT,
|
|
|
NamedCType,
|
|
|
tensorT,
|
|
|
)
|
|
|
from torchgen.model import (
|
|
|
Argument,
|
|
|
BaseTy,
|
|
|
BaseType,
|
|
|
FunctionSchema,
|
|
|
NativeFunction,
|
|
|
NativeFunctionsViewGroup,
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Binding named "base": a const-reference tensor C++ type paired with a
# Tensor-typed schema Argument of the same name. Neither the Argument nor
# the Binding carries a default value.
base_binding = Binding(
    name="base",
    nctype=NamedCType(
        name="base",
        type=ConstRefCType(BaseCType(tensorT)),
    ),
    argument=Argument(
        name="base",
        type=BaseType(BaseTy.Tensor),
        default=None,
        annotation=None,
    ),
    default=None,
)
|
|
|
# Binding named "mutated_view": a const-reference tensor C++ type.
# NOTE(review): the attached schema Argument is named "base" rather than
# "mutated_view" -- presumably it is only a placeholder; confirm before
# relying on `argument.name` for this binding.
mutated_view_binding = Binding(
    name="mutated_view",
    nctype=NamedCType(
        name="mutated_view",
        type=ConstRefCType(BaseCType(tensorT)),
    ),
    argument=Argument(
        name="base",
        type=BaseType(BaseTy.Tensor),
        default=None,
        annotation=None,
    ),
    default=None,
)
|
|
|
# Binding named "mutated_view_idx" whose C++ type is built from `longT`.
# NOTE(review): the attached schema Argument is a Tensor named "base", which
# matches neither the binding name nor its C++ type -- presumably a
# placeholder; confirm before relying on `argument` for this binding.
mutated_view_idx_binding = Binding(
    name="mutated_view_idx",
    nctype=NamedCType(
        name="mutated_view_idx",
        type=BaseCType(longT),
    ),
    argument=Argument(
        name="base",
        type=BaseType(BaseTy.Tensor),
        default=None,
        annotation=None,
    ),
    default=None,
)
|
|
|
# Binding named "reapply_views": a `boolT` C++ type paired with a bool-typed
# schema Argument of the same name. No defaults are attached.
reapply_views_binding = Binding(
    name="reapply_views",
    nctype=NamedCType(
        name="reapply_views",
        type=BaseCType(boolT),
    ),
    argument=Argument(
        name="reapply_views",
        type=BaseType(BaseTy.bool),
        default=None,
        annotation=None,
    ),
    default=None,
)
|
|
|
|
|
|
# C++ type handle constructed from "at::functionalization" and
# "InverseReturnMode".
InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode")

# Binding named "inverse_return_mode" whose C++ type is `InverseReturnModeT`.
# NOTE(review): the schema Argument is declared with a bool type even though
# the C++ type is not bool -- presumably the Argument is only a placeholder;
# confirm before relying on `argument.type` for this binding.
inverse_return_mode_binding = Binding(
    name="inverse_return_mode",
    nctype=NamedCType(
        name="inverse_return_mode",
        type=BaseCType(InverseReturnModeT),
    ),
    argument=Argument(
        name="inverse_return_mode",
        type=BaseType(BaseTy.bool),
        default=None,
        annotation=None,
    ),
    default=None,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def name(
    g: NativeFunctionsViewGroup,
    *,
    is_reverse: bool,
    include_namespace: bool,
    reapply_views: bool | None = None,
) -> str:
    """Return the name of the function to call for the view group ``g``.

    Args:
        g: View group supplying the ``view`` and ``view_copy`` operators.
        is_reverse: When true, the result is ``reverse_name(g.view, ...)``.
        include_namespace: Forwarded to ``reverse_name`` in the reverse case;
            must be true in the forward case.
        reapply_views: Forward case only. A truthy value selects ``g.view``,
            a falsy value selects ``g.view_copy``. May be left as ``None``
            only when ``is_reverse`` is true.

    Returns:
        In the reverse case, whatever ``reverse_name`` returns. In the forward
        case, ``"at::_ops::<unambiguous op name>::call"``.

    Raises:
        AssertionError: If ``reapply_views`` is ``None`` while ``is_reverse``
            is false, or (forward case) if ``include_namespace`` is false or
            ``g.view_copy`` is ``None``.
    """
    # Leaving reapply_views unspecified is only permitted for the reverse name.
    assert reapply_views is not None or is_reverse

    if is_reverse:
        return reverse_name(g.view, include_namespace)

    # Forward case: the result is always fully qualified.
    assert include_namespace
    assert g.view_copy is not None

    selected_op = g.view if reapply_views else g.view_copy
    return f"at::_ops::{selected_op.func.name.unambiguous_name()}::call"
|
|
|
|
|
|
|
|
|
def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
    """Return the name of the inverse function for the operator ``f``.

    Args:
        f: Operator whose ``func.name.unambiguous_name()`` forms the stem.
        include_namespace: When true, prefix the result with
            ``at::functionalization::FunctionalInverses::``.

    Returns:
        ``"<stem>_inverse"``, optionally namespace-qualified.
    """
    op_name = f.func.name.unambiguous_name()
    return (
        f"at::functionalization::FunctionalInverses::{op_name}_inverse"
        if include_namespace
        else f"{op_name}_inverse"
    )
|
|
|
|
|
|
|
|
|
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
    """Build the capture bindings for ``func``.

    The first flat argument of ``func`` must be a Tensor and is excluded.
    Every remaining argument is converted with
    ``dispatcher.argument(..., remove_non_owning_ref_types=True)``.

    Args:
        func: Schema whose ``arguments.flat_all`` is inspected.
        is_reverse: Chooses the leading binding: ``inverse_return_mode_binding``
            when true, ``reapply_views_binding`` otherwise.

    Returns:
        A new list: the mode binding followed by the converted non-first
        arguments, in schema order.

    Raises:
        AssertionError: If the first flat argument is not of Tensor type.
    """
    flat_args = func.arguments.flat_all
    assert flat_args[0].type == BaseType(BaseTy.Tensor)

    if is_reverse:
        mode_binding = inverse_return_mode_binding
    else:
        mode_binding = reapply_views_binding

    return [
        mode_binding,
        *(
            dispatcher.argument(arg, remove_non_owning_ref_types=True)
            for arg in flat_args[1:]
        ),
    ]
|
|
|
|
|
|
|
|
|
def returns_type(func: FunctionSchema) -> CType:
    """Return the C++ return type associated with ``func``.

    The schema is validated (at least one return, every return tensor-like),
    but the result does not depend on how many returns there are: it is
    always a single ``BaseCType(tensorT)``.

    Raises:
        AssertionError: If ``func`` has no returns, or any return's type is
            not tensor-like.
    """
    outputs = func.returns
    assert len(outputs) >= 1
    assert all(out.type.is_tensor_like() for out in outputs)
    return BaseCType(tensorT)
|
|
|
|
|
|
|
|
|
def outer_arguments(*, is_reverse: bool) -> list[Binding]:
    """Return the outer bindings for the forward or reverse case.

    Args:
        is_reverse: When true, ``mutated_view_binding`` is inserted between
            ``base_binding`` and ``mutated_view_idx_binding``.

    Returns:
        A new list: ``[base, mutated_view, mutated_view_idx]`` for the reverse
        case, ``[base, mutated_view_idx]`` otherwise.
    """
    bindings = [base_binding]
    if is_reverse:
        bindings.append(mutated_view_binding)
    bindings.append(mutated_view_idx_binding)
    return bindings
|
|
|
|
|
|
|
|
|
def inner_call_index(func: FunctionSchema) -> Binding | None:
    """Return the index binding for ``func`` if it needs one, else ``None``.

    ``mutated_view_idx_binding`` is returned when ``func`` has more than one
    return, or exactly one return whose type reports itself as list-like.
    In every other case (including no returns) the result is ``None``.
    """
    num_returns = len(func.returns)
    if num_returns > 1:
        return mutated_view_idx_binding
    if num_returns == 1 and func.returns[0].type.is_list_like():
        return mutated_view_idx_binding
    return None
|
|
|
|
|
|
|
|
|
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
    """Build the bindings passed to the inner call for ``func``.

    The first flat argument of ``func`` must be a Tensor; it is replaced by
    fixed bindings. The remaining arguments are converted with
    ``dispatcher.argument`` and appended in schema order.

    Args:
        func: Schema whose ``arguments.flat_all`` is inspected.
        is_reverse: Selects the fixed prefix:
            forward -> ``[base]``;
            reverse -> ``[base, mutated_view, inverse_return_mode]`` plus the
            binding from ``inner_call_index(func)`` when that is not ``None``.

    Returns:
        A new list consisting of the fixed prefix followed by the converted
        non-first arguments.

    Raises:
        AssertionError: If the first flat argument is not of Tensor type.
    """
    flat_args = func.arguments.flat_all
    assert flat_args[0].type == BaseType(BaseTy.Tensor)
    trailing = [dispatcher.argument(arg) for arg in flat_args[1:]]

    if not is_reverse:
        return [base_binding, *trailing]

    leading = [base_binding, mutated_view_binding, inverse_return_mode_binding]
    index_binding = inner_call_index(func)
    if index_binding is not None:
        # The index binding is placed after the fixed reverse bindings and
        # before the schema-derived arguments.
        leading.append(index_binding)
    return leading + trailing
|
|
|
|