|
|
from __future__ import annotations
|
|
|
|
|
|
from dataclasses import dataclass
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
|
|
from torchgen.api import cpp
|
|
|
from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
|
|
|
from torchgen.gen import pythonify_default
|
|
|
from torchgen.model import (
|
|
|
Argument,
|
|
|
BaseTy,
|
|
|
BaseType,
|
|
|
FunctionSchema,
|
|
|
ListType,
|
|
|
NativeFunction,
|
|
|
OptionalType,
|
|
|
Return,
|
|
|
Type,
|
|
|
Variant,
|
|
|
)
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
from collections.abc import Iterable, Sequence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_function_signature(
    name: str, arguments: Iterable[str] = (), return_type: str | None = None
) -> str:
    """Render a stub-style ``def`` whose body is ``...``.

    Args:
        name: Function name to emit after ``def``.
        arguments: Already-rendered parameter strings. Any iterable is
            accepted; non list/tuple inputs are materialised once.
        return_type: Rendered return annotation, or ``None`` to omit ``->``.

    Returns:
        A single line when it fits in 80 columns, when there are no
        arguments, or when the only argument is ``self``. Otherwise one
        parameter per line. If any of those lines is still longer than 80
        columns, a ``# fmt: skip`` marker is attached to the closing line and
        the ``...`` body moves to its own line.
    """
    if not isinstance(arguments, (list, tuple)):
        arguments = tuple(arguments)
    return_type = f" -> {return_type}" if return_type is not None else ""

    sig = f"def {name}({', '.join(arguments)}){return_type}: ..."
    if len(sig) <= 80 or not arguments or tuple(arguments) == ("self",):
        return sig

    # Emitted text follows the usual layout conventions: 4-space indentation
    # for wrapped parameters and the body, and two spaces before an inline
    # comment.
    lines = [
        f"def {name}(",
        *(f"    {arg}," for arg in arguments),
        f"){return_type}: ...",
    ]
    sig = "\n".join(lines)
    if all(len(line) <= 80 for line in lines):
        return sig

    # At least one line cannot be shortened: mark the statement with
    # `fmt: skip` and put the `...` body on an indented line of its own.
    return sig.removesuffix(" ...") + "  # fmt: skip\n    ..."
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class PythonReturns:
    """Immutable wrapper around the return entries of a signature."""

    # Return entries, stored in the order they were given.
    returns: tuple[Return, ...]
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class PythonArgument:
    """One argument of a Python signature, renderable in two text forms.

    ``argument_str`` produces the ``<type> <name>[=<default>]`` form and
    ``argument_str_pyi`` produces the ``<name>: <type>[ = <default>]`` form.
    """

    name: str
    type: Type
    # Default value as text; ``None`` means the argument has no default.
    default: str | None
    # Optional initializer expression; not read by the methods of this class.
    default_init: str | None

    def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
        """Render this argument as ``<type> <name>`` plus an optional default.

        ``const `` and `` &`` are stripped from the type text. A ``self``
        argument typed ``Tensor`` or ``Number`` is renamed to ``input``
        unless ``method`` is true. Null-like default spellings are written
        as ``None``.
        """
        raw_type = argument_type_str(self.type, symint=symint)
        type_str = raw_type.replace("const ", "").replace(" &", "")

        rename_self = (
            self.name == "self" and not method and type_str in ["Tensor", "Number"]
        )
        name = "input" if rename_self else self.name

        if self.default is None:
            return f"{type_str} {name}"

        null_spellings = ("nullptr", "::std::nullopt", "std::nullopt", "{}")
        default = "None" if self.default in null_spellings else self.default
        return f"{type_str} {name}={default}"

    def argument_str_pyi(
        self, *, method: bool = False, deprecated: bool = False
    ) -> str:
        """Render this argument as ``<name>: <type>`` plus an optional default.

        Args:
            method: When true, a ``self`` argument keeps its name.
            deprecated: When true, ``self`` is not renamed, ``out`` is not
                widened to ``Tensor | None``, and a ``PythonOutArgument``
                whose default is ``"None"`` is rendered without a default.
        """
        type_str = argument_type_str_pyi(self.type)
        is_plain_tensor = type_str == "Tensor"

        name = self.name
        if name == "self" and is_plain_tensor and not (method or deprecated):
            name = "input"

        # `from` is a Python keyword, so it gets a trailing underscore.
        if name == "from":
            name += "_"

        if name == "out" and is_plain_tensor and not deprecated:
            type_str = "Tensor | None"

        suppress_default = (
            deprecated
            and isinstance(self, PythonOutArgument)
            and self.default == "None"
        )
        if self.default is None or suppress_default:
            return f"{name}: {type_str}"

        is_braced_int_list = (
            isinstance(self.type, ListType)
            and self.type.elem == BaseType(BaseTy.int)
            and self.default.startswith("{")
            and self.default.endswith("}")
        )
        if is_braced_int_list:
            # Rewrite a `{a, b}` default as the parenthesised `(a, b)`.
            items = [item.strip() for item in self.default[1:-1].split(",")]
            default = f"({', '.join(items)})"
        else:
            translations = {
                "nullptr": "None",
                "::std::nullopt": "None",
                "std::nullopt": "None",
                "{}": "None",
                "c10::MemoryFormat::Contiguous": "contiguous_format",
                "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
            }
            default = translations.get(self.default, self.default)
        return f"{name}: {type_str} = {default}"
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class PythonOutArgument(PythonArgument):
    """A single Python-level argument that stands for one or more outputs."""

    # The individual output arguments this argument was built from.
    outputs: tuple[PythonArgument, ...]

    @staticmethod
    def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
        """Combine ``outputs`` into one out argument.

        Returns ``None`` for an empty input. A single output keeps its own
        name and type. Several outputs are packed into one argument named
        ``out`` whose type is a sized list of ``Tensor``. Both forms default
        to ``"None"``.

        Raises:
            RuntimeError: several outputs were given and at least one of them
                is not tensor-like.
        """
        if not outputs:
            return None

        if len(outputs) == 1:
            only = outputs[0]
            return PythonOutArgument(
                name=only.name,
                type=only.type,
                default="None",
                default_init=None,
                outputs=outputs,
            )

        if not all(a.type.is_tensor_like() for a in outputs):
            raise RuntimeError(f"Unsupported output type: {outputs}")
        return PythonOutArgument(
            name="out",
            type=ListType(BaseType(BaseTy.Tensor), len(outputs)),
            default="None",
            default_init=None,
            outputs=outputs,
        )
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class PythonSignature:
    """A Python-level signature, renderable as text in several forms."""

    # Name used when rendering the signature.
    name: str

    # Positional arguments; rendered first, before the `*` separator.
    input_args: tuple[PythonArgument, ...]

    # Arguments rendered after the `*` separator.
    input_kwargs: tuple[PythonArgument, ...]

    # Combined output argument, or None when there is none.
    output_args: PythonOutArgument | None

    # Return entries of the signature.
    returns: PythonReturns

    # Extra arguments appended last by `arguments()` unless skipped.
    tensor_options_args: tuple[PythonArgument, ...]

    # When true, `self` is prepended to the rendered pyi signatures.
    method: bool

    @property
    def deprecated(self) -> bool:
        """Always ``False`` for this class."""
        return False

    def arguments(
        self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
    ) -> tuple[PythonArgument | PythonOutArgument, ...]:
        """Return the arguments in rendering order.

        The order is: positional inputs, keyword inputs, the output argument
        (unless ``skip_outputs``), then the tensor-options arguments (unless
        ``skip_tensor_options``).
        """
        result: list[PythonArgument | PythonOutArgument] = []
        result.extend(self.input_args)
        result.extend(self.input_kwargs)
        if self.output_args is not None and not skip_outputs:
            result.append(self.output_args)
        if not skip_tensor_options:
            result.extend(self.tensor_options_args)
        return tuple(result)

    def arguments_count(self) -> int:
        """Return the number of arguments with nothing skipped."""
        return len(self.arguments())

    def output_idx(self) -> int:
        """Return the position of the output argument within ``arguments()``."""
        return len(self.input_args) + len(self.input_kwargs)

    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
        """Render ``name(<type> <arg>, ..., *, <type> <kwarg>, ...)``.

        A bare ``*`` is inserted after the positional inputs whenever any
        argument follows them.
        """
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str(method=self.method, symint=symint) for a in args
        ]
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        return f"{self.name}({', '.join(schema_formals)})"

    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
        """Render the signature as a stub ``def`` with a return annotation."""
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method) for a in args
        ]
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        returns_str = returns_str_pyi(self)

        if self.method:
            schema_formals.insert(0, "self")
        return format_function_signature(self.name, schema_formals, returns_str)

    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
        """Render the ``*args`` form of the stub signature, if one applies.

        It applies only when there is exactly one positional input and the
        first argument is a list of ``int`` or ``SymInt``. That argument is
        then rendered as ``*<name>: <element type>``. Returns ``None`` in
        every other case, including when there is nothing to render.
        """
        args = self.arguments(skip_outputs=skip_outputs)
        # Guard on the arguments actually being rendered. `arguments_count()`
        # never skips outputs, so with `skip_outputs=True` it can be non-zero
        # while `args` is empty, which made `args[0]` below raise IndexError.
        if not args:
            return None

        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method) for a in args
        ]

        num_positionalargs = len(self.input_args)

        vararg_type = args[0].type
        if not (
            isinstance(vararg_type, ListType)
            and str(vararg_type.elem) in ["int", "SymInt"]
            and num_positionalargs == 1
        ):
            return None

        # Already implied by the check above; kept for static type narrowing.
        assert isinstance(vararg_type, ListType)
        schema_formals[0] = (
            "*" + args[0].name + ": " + argument_type_str_pyi(vararg_type.elem)
        )

        returns_str = returns_str_pyi(self)

        if self.method:
            schema_formals.insert(0, "self")
        return format_function_signature(self.name, schema_formals, returns_str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class PythonSignatureDeprecated(PythonSignature):
    """A ``PythonSignature`` variant whose ``deprecated`` flag is true."""

    # Schema used in place of the native function's own schema when the
    # dispatch lambda arguments are built.
    deprecated_schema: FunctionSchema

    # Argument expressions used when building the C++ dispatch call.
    deprecated_args_exprs: tuple[str, ...]

    @property
    def deprecated(self) -> bool:
        """Always ``True`` for this class."""
        return True

    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
        """Return the base rendering with ``|deprecated`` appended."""
        base = super().signature_str(skip_outputs=skip_outputs, symint=symint)
        return f"{base}|deprecated"

    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
        """Render the stub ``def`` with every argument in deprecated mode.

        Unlike the base class, ``self`` is not prepended here.
        """
        formals = [
            a.argument_str_pyi(method=self.method, deprecated=True)
            for a in self.arguments(skip_outputs=skip_outputs)
        ]
        split_at = len(self.input_args)
        if split_at < len(formals):
            formals.insert(split_at, "*")

        return format_function_signature(self.name, formals, returns_str_pyi(self))

    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
        """Always ``None``: this class has no ``*args`` form."""
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class PythonSignatureNativeFunctionPair:
    """Pairs a Python signature with a native function."""

    # The Python signature half of the pair.
    signature: PythonSignature
    # The native function half of the pair.
    function: NativeFunction
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class PythonSignatureGroup:
    """Groups one signature with a base function and an optional out variant."""

    # Signature describing the whole group.
    signature: PythonSignature

    # Function taken from the functional pair.
    base: NativeFunction

    # Function taken from the out pair, or None when there is no out pair.
    outplace: NativeFunction | None

    @classmethod
    def from_pairs(
        cls,
        functional: PythonSignatureNativeFunctionPair,
        out: PythonSignatureNativeFunctionPair | None,
    ) -> PythonSignatureGroup:
        """Build a group from a functional pair and an optional out pair.

        Without an out pair the functional signature is used unchanged. With
        one, the group signature is a copy of the out signature whose
        ``tensor_options_args`` are replaced by the functional signature's.
        The out signature itself is not modified.

        Instances are created through ``cls`` so that calling this on a
        subclass yields an instance of that subclass.
        """
        if out is None:
            return cls(
                signature=functional.signature,
                base=functional.function,
                outplace=None,
            )

        # Copy the out signature's fields, overriding only the tensor options,
        # and rebuild it with its own concrete type.
        signature_kwargs = out.signature.__dict__.copy()
        signature_kwargs["tensor_options_args"] = (
            functional.signature.tensor_options_args
        )

        return cls(
            signature=type(out.signature)(**signature_kwargs),
            base=functional.function,
            outplace=out.function,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class DispatchLambdaArgument:
    """One parameter of a dispatch lambda."""

    # Parameter name.
    name: str
    # C++ type text of the parameter.
    type_str: str
    # True when the parameter is one of the schema's out arguments.
    is_out_arg: bool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class PythonArgParserOutputExpr:
    """An expression that reads one parsed argument from ``_r``."""

    # Argument name.
    name: str

    # Expression text that unpacks the argument.
    expr: str

    # Position of the argument, used to index into `_r`.
    index: int

    # The argument this expression was generated for.
    argument: PythonArgument

    @property
    def is_none_expr(self) -> str:
        """Expression text calling ``_r.isNone`` with this argument's index."""
        return "_r.isNone({})".format(self.index)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class DispatchLambdaArgumentExprs:
    """Expression strings and init statements for a dispatch lambda call."""

    # Expression strings, one per lambda argument (presumably in lambda
    # parameter order; the producer is not fully visible here).
    exprs: Sequence[str]

    # Statement strings to emit before the expressions are used.
    inits: Sequence[str]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature:
    """Return the ``signature`` of the C++ signature group built for ``f``."""
    group = CppSignatureGroup.from_native_function(f, method=method)
    return group.signature
|
|
|
|
|
|
|
|
|
def has_tensor_options(f: NativeFunction) -> bool:
    """Return whether the schema of ``f`` has tensor-options arguments."""
    tensor_options = f.func.arguments.tensor_options
    return tensor_options is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def argument_type_str(
    t: Type, *, simple_type: bool = False, symint: bool = True
) -> str:
    """Return the type text used in the ``<type> <name>`` signature form.

    Args:
        t: Type to render.
        simple_type: When true, list sizes are omitted, and a list of
            optional tensors is rendered without ``const`` / ``&``.
        symint: When false, a ``SymInt`` list is rendered as ``IntArrayRef``.

    Raises:
        RuntimeError: ``t`` is a base type not handled here, or is not a
            base, optional or list type at all.
    """
    if isinstance(t, BaseType):
        renamed = {
            BaseTy.int: "int64_t",
            BaseTy.float: "double",
            BaseTy.str: "c10::string_view",
        }
        if t.name in renamed:
            return renamed[t.name]

        spelled_as_is = (
            BaseTy.Tensor,
            BaseTy.bool,
            BaseTy.QScheme,
            BaseTy.Scalar,
            BaseTy.ScalarType,
            BaseTy.Generator,
            BaseTy.Storage,
            BaseTy.Layout,
            BaseTy.Device,
            BaseTy.DeviceIndex,
            BaseTy.MemoryFormat,
            BaseTy.Dimname,
            BaseTy.Stream,
            BaseTy.SymInt,
        )
        if t.name in spelled_as_is:
            return t.name.name
        # Any other base type falls through to the error below.

    elif isinstance(t, OptionalType):
        inner = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
        return f"{inner}?"

    elif isinstance(t, ListType):
        elem_name = str(t.elem)

        if elem_name == "bool":
            # bool lists always carry their size, even with simple_type.
            assert t.size is not None
            return f"::std::array<bool,{t.size}>"

        if elem_name == "Tensor?":
            if simple_type:
                return "c10::List<::std::optional<Tensor>>"
            return "const c10::List<::std::optional<Tensor>> &"

        size = None if simple_type else t.size
        size_suffix = "" if size is None else f"[{size}]"
        list_names = {
            "int": "IntArrayRef",
            "SymInt": "SymIntArrayRef" if symint else "IntArrayRef",
            "Tensor": "TensorList",
            "Scalar": "ScalarList",
            "Dimname": "DimnameList",
        }
        if elem_name in list_names:
            return list_names[elem_name] + size_suffix

        inner = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
        return f"ArrayRef<{inner}>"

    raise RuntimeError(f"unrecognized type {repr(t)}")
|
|
|
|
|
|
|
|
|
def argument_type_size(t: Type) -> int | None:
    """Return the list size of ``t``, or ``None`` when there is none to report.

    ``None`` is returned when ``t`` is not list-like and when its elements
    are ``bool``.
    """
    list_type = t.is_list_like()
    if list_type is None or str(list_type.elem) == "bool":
        return None
    return list_type.size
|
|
|
|
|
|
|
|
|
def argument(a: Argument) -> PythonArgument:
    """Convert a schema argument into a ``PythonArgument``.

    A schema default is first turned into a C++ default expression (with
    ``symint=False``), then passed through ``pythonify_default`` and stored
    as a string. ``default_init`` is always ``None``.
    """
    default: str | None = None
    if a.default is not None:
        cpp_default = cpp.default_expr(a.default, a.type, symint=False)
        default = str(pythonify_default(cpp_default))

    return PythonArgument(
        name=a.name,
        type=a.type,
        default=default,
        default_init=None,
    )
|
|
|
|
|
|
|
|
|
|
|
|
def signature(
    f: NativeFunction, *, method: bool = False, pyi: bool = False
) -> PythonSignature:
    """Build the Python signature for ``f``.

    Forwards the function's schema and ``category_override`` to
    ``signature_from_schema`` together with ``method`` and ``pyi``.
    """
    schema = f.func
    override = f.category_override
    return signature_from_schema(
        schema, category_override=override, method=method, pyi=pyi
    )
|
|
|
|
|
|
|
|
|
def signature_from_schema(
    func: FunctionSchema,
    *,
    category_override: str | None,
    method: bool = False,
    pyi: bool = False,
) -> PythonSignature:
    """Build a ``PythonSignature`` from a function schema.

    Args:
        func: Schema to convert.
        category_override: Forces the function category. ``"factory"``,
            ``"new"`` and ``"like"`` add the tensor-options arguments, and
            ``"dummy"`` suppresses them.
        method: When true, the schema's ``self`` argument is left out.
        pyi: Accepted for interface compatibility; not read here.

    Raises:
        ValueError: the schema declares an argument named ``requires_grad``.
    """
    schema_args = func.arguments

    # Collect the schema arguments in Python order. The tensor-options
    # arguments are not taken from the schema; they are added separately below.
    ordered: list[Argument] = list(schema_args.pre_self_positional)
    if schema_args.self_arg is not None and not method:
        ordered.append(schema_args.self_arg.argument)
    ordered += schema_args.post_self_positional
    ordered += schema_args.pre_tensor_options_kwarg_only
    ordered += schema_args.post_tensor_options_kwarg_only
    ordered += schema_args.out

    positional_names = {a.name for a in schema_args.flat_positional}
    kwarg_only_names = {a.name for a in schema_args.flat_kwarg_only}
    out_names = {a.name for a in schema_args.out}

    def convert(names: set[str]) -> tuple[PythonArgument, ...]:
        """Convert the collected arguments whose name is in ``names``."""
        return tuple(argument(a) for a in ordered if a.name in names)

    input_args = convert(positional_names)
    input_kwargs = convert(kwarg_only_names)
    outputs = convert(out_names)

    has_tensor_input_arg = any(
        a.type.is_tensor_like() for a in schema_args.flat_non_out
    )
    for schema_arg in func.schema_order_arguments():
        if schema_arg.name == "requires_grad":
            raise ValueError(
                "argument named requires_grad is reserved, should not explicitly add it in the schema"
            )

    has_tensor_return = any(r.type.is_tensor_like() for r in func.returns)

    name: str = cpp.name(func)
    is_factory_function = category_override == "factory" or (
        has_tensor_return and not has_tensor_input_arg
    )
    is_like_or_new_function = (
        category_override in ("new", "like")
        or name.startswith("new_")
        or name.endswith("_like")
    )
    is_dummy_function = category_override == "dummy"

    tensor_options_args: list[PythonArgument] = []
    if (is_factory_function or is_like_or_new_function) and not is_dummy_function:

        def topt_default_init(name: str) -> str | None:
            """Return the C++ default of a schema tensor-options field, if any."""
            topt_args = schema_args.tensor_options
            if topt_args is None:
                return None
            field = getattr(topt_args, name)
            if field.default is None or field.default == "None":
                return None
            return cpp.default_expr(field.default, field.type, symint=False)

        def optional_of(base: BaseTy) -> OptionalType:
            return OptionalType(BaseType(base))

        # like/new functions never get an explicit initializer.
        use_schema_defaults = not is_like_or_new_function

        tensor_options_args = [
            PythonArgument(
                name="dtype",
                type=optional_of(BaseTy.ScalarType),
                default="None",
                default_init=(
                    topt_default_init("dtype") if use_schema_defaults else None
                ),
            ),
            PythonArgument(
                name="layout",
                type=optional_of(BaseTy.Layout),
                default="None",
                default_init=(
                    topt_default_init("layout") if use_schema_defaults else None
                ),
            ),
            PythonArgument(
                name="device",
                type=optional_of(BaseTy.Device),
                default="None",
                default_init=(
                    (
                        topt_default_init("device")
                        or "torch::tensors::get_default_device()"
                    )
                    if use_schema_defaults
                    else None
                ),
            ),
            PythonArgument(
                name="pin_memory",
                type=optional_of(BaseTy.bool),
                default="False",
                default_init=None,
            ),
            PythonArgument(
                name="requires_grad",
                type=optional_of(BaseTy.bool),
                default="False",
                default_init=None,
            ),
        ]

    return PythonSignature(
        name=str(func.name.name),
        input_args=input_args,
        input_kwargs=input_kwargs,
        output_args=PythonOutArgument.from_outputs(outputs),
        tensor_options_args=tuple(tensor_options_args),
        returns=PythonReturns(returns=func.returns),
        method=method,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]:
    """Return the field names for a multi-value return, or ``[]``.

    The result is empty when there are fewer than two returns or when none
    of them is named.

    Raises:
        ValueError: only some of the returns are named.
    """
    if len(returns) <= 1:
        return []

    unnamed = sum(1 for r in returns if r.name is None)
    if unnamed == len(returns):
        return []
    if unnamed:
        raise ValueError("Unnamed field is not supported by codegen")

    return [str(r.name) for r in returns]
|
|
|
|
|
|
|
|
|
def argument_type_str_pyi(t: Type) -> str:
    """Return the annotation text for ``t`` as an argument in a stub.

    One level of ``OptionalType`` is unwrapped and turned into a trailing
    ``| None``. Tensor-like lists also get ``| None``. A base type without an
    entry below yields an empty type string.

    Raises:
        RuntimeError: the (unwrapped) type is neither a base nor a list type.
    """
    add_optional = isinstance(t, OptionalType)
    if add_optional:
        t = t.elem

    if isinstance(t, BaseType):
        if t.name in (BaseTy.Tensor, BaseTy.Generator, BaseTy.Stream):
            # These are spelled exactly like the base type's own name.
            ret = t.name.name
        else:
            base_names = {
                BaseTy.int: "_int",
                BaseTy.DeviceIndex: "_int",
                BaseTy.SymInt: "_int | SymInt",
                BaseTy.float: "_float",
                BaseTy.str: "str",
                BaseTy.Scalar: "Number | _complex",
                BaseTy.ScalarType: "_dtype",
                BaseTy.bool: "_bool",
                BaseTy.QScheme: "_qscheme",
                BaseTy.Layout: "_layout",
                BaseTy.Device: "DeviceLikeType | None",
                BaseTy.MemoryFormat: "memory_format",
                BaseTy.Dimname: "str | EllipsisType | None",
                BaseTy.Storage: "Storage | UntypedStorage",
            }
            ret = base_names.get(t.name, "")

    elif isinstance(t, ListType):
        elem_name = str(t.elem)
        sized = t.size is not None

        if elem_name == "int":
            ret = "_int | _size" if sized else "_size"
        elif t.is_tensor_like():
            add_optional = True
            sequences = "tuple[Tensor, ...] | list[Tensor]"
            ret = f"Tensor | {sequences}" if sized else sequences
        elif elem_name == "float":
            ret = "Sequence[_float]"
        else:
            elem = argument_type_str_pyi(t.elem)
            if elem_name == "SymInt" and sized:
                ret = f"{elem} | Sequence[{elem}]"
            else:
                ret = f"Sequence[{elem}]"

    else:
        raise RuntimeError(f"unrecognized type {repr(t)}")

    if add_optional:
        # Collapse the duplicate when the type text already ends in `| None`.
        ret = f"{ret} | None".replace(" | None | None", " | None")

    return ret
|
|
|
|
|
|
|
|
|
def return_type_str_pyi(t: Type) -> str:
    """Return the annotation text for ``t`` as a return value in a stub.

    Optionals become ``<inner> | None`` and lists become
    ``tuple[<inner>, ...]``. ``Device`` and ``Dimname`` have return-specific
    spellings. Everything else uses the argument spelling.
    """
    if isinstance(t, OptionalType):
        wrapped = f"{return_type_str_pyi(t.elem)} | None"
        return wrapped.replace(" | None | None", " | None")

    if isinstance(t, BaseType):
        overrides = {
            BaseTy.Device: "_device",
            BaseTy.Dimname: "str | None",
        }
        special = overrides.get(t.name)
        return special if special is not None else argument_type_str_pyi(t)

    if isinstance(t, ListType):
        return f"tuple[{return_type_str_pyi(t.elem)}, ...]"

    return argument_type_str_pyi(t)
|
|
|
|
|
|
|
|
|
def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
    """Build the stub class for a signature with named returns.

    Returns:
        ``(class name, class definition text)`` when the returns have field
        names, otherwise ``None``. The class is named after the signature,
        derives from a ``tuple[...]`` of the return types, and exposes one
        read-only property per field.
    """
    python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
    structseq_name = signature.name
    field_names = structseq_fieldnames(signature.returns.returns)
    if not field_names:
        return None

    # The emitted text must be valid Python on its own: class members are
    # indented 4 spaces and everything nested inside `__new__` 8 spaces, so
    # the `...` body sits deeper than its `def`. Inline comments get two
    # leading spaces.
    seq_type = f"tuple[{', '.join(python_returns)}]"
    structseq_def_lines = [
        f"class {structseq_name}({seq_type}):  # fmt: skip",
    ]
    for name, ret_type in zip(field_names, python_returns):
        structseq_def_lines.extend(
            [
                "    @property",
                f"    def {name}(self) -> {ret_type}: ...",
            ]
        )
    structseq_def_lines.extend(
        [
            "    def __new__(",
            "        cls,",
            f"        sequence: {seq_type},",
            "    ) -> Self:  # fmt: skip",
            "        ...",
            f"    n_fields: Final[_int] = {len(field_names)}",
            f"    n_sequence_fields: Final[_int] = {len(field_names)}",
            "    n_unnamed_fields: Final[_int] = 0",
            "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing",
            "",
        ]
    )
    structseq_def = "\n".join(structseq_def_lines)
    return structseq_name, structseq_def
|
|
|
|
|
|
|
|
|
def returns_str_pyi(signature: PythonSignature) -> str:
    """Return the return-annotation text for ``signature`` in a stub.

    Named returns refer to ``torch.return_types.<name>``. Otherwise the
    result is ``None`` for no returns, the bare type for one, and a
    ``tuple[...]`` for several.
    """
    if structseq_fieldnames(signature.returns.returns):
        return f"torch.return_types.{signature.name}"

    rendered = [return_type_str_pyi(r.type) for r in signature.returns.returns]
    if not rendered:
        return "None"
    if len(rendered) == 1:
        return rendered[0]
    return f"tuple[{', '.join(rendered)}]"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dispatch_lambda_args(
    ps: PythonSignature, f: NativeFunction, symint: bool = True
) -> tuple[DispatchLambdaArgument, ...]:
    """Return the parameters of the dispatch lambda for ``ps`` / ``f``.

    The C++ arguments come from the deprecated schema when ``ps`` is a
    ``PythonSignatureDeprecated`` and from ``f.func`` otherwise. Type
    adjustments:

    * ``self`` of a method signature is always ``const at::Tensor &``;
    * ``at::Tensor &`` becomes ``at::Tensor``, except for out arguments of a
      schema with more than one out argument, which keep the reference.
    """
    if isinstance(ps, PythonSignatureDeprecated):
        schema = ps.deprecated_schema
    else:
        schema = f.func

    cpp_args = cpp.arguments(
        arguments=schema.arguments,
        faithful=False,
        symint=symint,
        method=False,
        cpp_no_default_args=f.cpp_no_default_args,
    )
    out_args: set[str] = {a.name for a in schema.arguments.out}
    has_multiple_outs = len(out_args) > 1

    lambda_args: list[DispatchLambdaArgument] = []
    for binding in cpp_args:
        is_out_arg = binding.name in out_args

        if ps.method and binding.name == "self":
            type_str = "const at::Tensor &"
        elif has_multiple_outs and is_out_arg:
            type_str = binding.type
        elif binding.type == "at::Tensor &":
            type_str = "at::Tensor"
        else:
            type_str = binding.type

        lambda_args.append(
            DispatchLambdaArgument(
                name=binding.name,
                type_str=type_str,
                is_out_arg=is_out_arg,
            )
        )

    return tuple(lambda_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# C++ return type spellings accepted by `dispatch_lambda_return_str`; any
# other spelling makes it raise.
SUPPORTED_RETURN_TYPES = {
    # Scalars and other single values.
    "void",
    "void*",
    "bool",
    "int64_t",
    "double",
    "at::Scalar",
    "at::ScalarType",
    "at::QScheme",
    "at::IntArrayRef",
    "at::Stream",
    # A single tensor or a list of tensors.
    "at::Tensor",
    "::std::vector<at::Tensor>",
    # Tuples made only of tensors.
    "::std::tuple<at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    # Mixed tuples.
    "::std::tuple<double,int64_t>",
    "::std::tuple<at::Tensor,::std::vector<at::Tensor>>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt,at::Tensor,at::Tensor,at::Tensor>",
}
|
|
|
|
|
|
|
|
|
def dispatch_lambda_return_str(f: NativeFunction) -> str:
    """Return the C++ return type of the dispatch lambda for ``f``.

    The returns are rebuilt with ``None`` as their third field before the
    C++ type is computed (with ``symint=True``).

    Raises:
        RuntimeError: the resulting type is not in ``SUPPORTED_RETURN_TYPES``.
    """
    returns_without_annotation = tuple(
        Return(ret.name, ret.type, None) for ret in f.func.returns
    )
    cpp_return = cpp.returns_type(returns_without_annotation, symint=True)
    return_str = cpp_return.cpp_type()
    if return_str in SUPPORTED_RETURN_TYPES:
        return return_str
    raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}")
|
|
|
|
|
|
|
|
|
def cpp_dispatch_target(f: NativeFunction) -> str:
    """Return the C++ callee text used to dispatch to ``f``.

    A method variant is called as ``self.<name>``. A function variant is
    called as ``torch::<name>`` when it has tensor options or its base name
    ends in ``_like``, and as ``at::<name>`` otherwise.

    Raises:
        RuntimeError: ``f`` is neither a method nor a function variant.
    """
    name = cpp.name(f.func, symint_overload=f.func.has_symint())

    if Variant.method in f.variants:
        return f"self.{name}"

    if Variant.function not in f.variants:
        raise RuntimeError(f"could not dispatch, neither function nor method: {f.func}")

    in_torch_namespace = has_tensor_options(f) or f.func.name.name.base.endswith(
        "_like"
    )
    namespace = "torch" if in_torch_namespace else "at"
    return f"{namespace}::{name}"
|
|
|
|
|
|
|
|
|
def cpp_dispatch_exprs(
    f: NativeFunction,
    *,
    python_signature: PythonSignature | None = None,
) -> tuple[str, ...]:
    """Return the argument expressions for the C++ dispatch call of ``f``.

    Normally these are the names of the function-style C++ arguments. For a
    deprecated signature they are its ``deprecated_args_exprs``, with
    ``"out"`` dropped unless ``f`` is an out function. For method variants
    ``"self"`` is removed from the result.
    """
    cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()

    if isinstance(python_signature, PythonSignatureDeprecated):
        exprs: tuple[str, ...] = tuple(
            expr
            for expr in python_signature.deprecated_args_exprs
            if expr != "out" or f.func.is_out_fn()
        )
    else:
        exprs = tuple(binding.name for binding in cpp_args)

    if Variant.method in f.variants:
        exprs = tuple(expr for expr in exprs if expr != "self")

    return exprs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def arg_parser_unpack_method(
    t: Type, default: str | None, default_init: str | None, *, symint: bool = True
) -> str:
    """Return the name of the ``_r`` method that unpacks an argument of type ``t``.

    Args:
        t: Type of the argument.
        default: Default value text of the argument, if any.
        default_init: Initializer expression, if any. Only ``ScalarType``,
            ``Device``, ``Layout`` and ``bool`` (plain or optional) accept
            one; for them it selects the ``...WithDefault`` method.
        symint: When false, ``SymInt`` values and lists are unpacked with the
            plain integer methods.

    Raises:
        RuntimeError: ``default_init`` is given for a type that does not
            accept one, or ``t`` has no unpack method.
    """
    has_default_init = default_init is not None
    if has_default_init and str(t) not in (
        "ScalarType?",
        "ScalarType",
        "Device",
        "Device?",
        "Layout",
        "Layout?",
        "bool",
        "bool?",
    ):
        raise RuntimeError(f"type '{t}' does not support unpacking with default")

    if isinstance(t, BaseType):
        if t.name in [
            BaseTy.Tensor,
            BaseTy.Stream,
            BaseTy.Storage,
            BaseTy.Scalar,
            BaseTy.Dimname,
        ]:
            # The method is the lower-cased base type name.
            return t.name.name.lower()
        elif t.name == BaseTy.ScalarType:
            return "scalartypeWithDefault" if has_default_init else "scalartype"
        elif t.name == BaseTy.Device:
            return "deviceWithDefault" if has_default_init else "device"
        elif t.name == BaseTy.DeviceIndex:
            return "toInt64"
        elif t.name == BaseTy.int:
            return "toInt64"
        elif t.name == BaseTy.SymInt:
            return "toSymInt" if symint else "toInt64"
        elif t.name == BaseTy.bool:
            return "toBoolWithDefault" if has_default_init else "toBool"
        elif t.name == BaseTy.float:
            return "toDouble"
        elif t.name == BaseTy.str:
            return "stringView"
        elif t.name == BaseTy.Layout:
            return "layoutWithDefault" if has_default_init else "layout"
        elif t.name == BaseTy.MemoryFormat:
            return "memoryformat"

    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            return "optionalTensor"
        elif str(t.elem) == "Generator":
            return "generator"
        elif str(t.elem) == "Dimname[]":
            return "toDimnameListOptional"
        elif not has_default_init and default in (
            None,
            "None",
            "::std::nullopt",
            "std::nullopt",
        ):
            # No initializer and a null-like (or missing) default: use the
            # element's method with an `Optional` suffix.
            return (
                arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
            )
        else:
            # Otherwise unpack exactly like the element type.
            return arg_parser_unpack_method(
                t.elem, default, default_init, symint=symint
            )

    elif isinstance(t, ListType):
        if str(t.elem) == "Tensor":
            return f"tensorlist_n<{t.size}>" if t.size is not None else "tensorlist"
        elif str(t.elem) == "Tensor?":
            return "list_of_optional_tensors"
        elif str(t.elem) == "Dimname":
            return "dimnamelist"
        elif str(t.elem) == "int":
            return "intlist"
        elif str(t.elem) == "float":
            return "doublelist"
        elif str(t.elem) == "SymInt":
            return "symintlist" if symint else "intlist"
        elif str(t.elem) == "Scalar":
            return "scalarlist"

    raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def arg_parser_output_expr(
    arg_index: int, a: PythonArgument, *, symint: bool = True
) -> PythonArgParserOutputExpr:
    """Build the ``_r.<method>(<index>[, <default_init>])`` expression for ``a``."""
    unpack_method = arg_parser_unpack_method(
        t=a.type, default=a.default, default_init=a.default_init, symint=symint
    )

    # The initializer, when present, is passed as a second call argument.
    if a.default_init is None:
        call_args = f"{arg_index}"
    else:
        call_args = f"{arg_index}, {a.default_init}"

    return PythonArgParserOutputExpr(
        name=a.name,
        expr=f"_r.{unpack_method}({call_args})",
        index=arg_index,
        argument=a,
    )
|
|
|
|
|
|
|
|
|
|
|
|
def arg_parser_output_exprs(
    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> dict[str, PythonArgParserOutputExpr]:
    """Return a map from argument name to its ``PythonArgParserOutputExpr``.

    Each argument of ``ps.arguments()`` is paired with its position in that
    sequence and converted with ``arg_parser_output_expr``.

    Args:
        ps: Python signature whose arguments are enumerated.
        f: Accepted for interface compatibility; not read by this function.
        symint: Forwarded to ``arg_parser_output_expr``.

    Returns:
        Dict keyed by argument name, in argument order.
    """
    outputs: dict[str, PythonArgParserOutputExpr] = {}
    for position, py_arg in enumerate(ps.arguments()):
        parsed = arg_parser_output_expr(position, py_arg, symint=symint)
        outputs[parsed.name] = parsed
    return outputs
|
|
|
|
|
|
|
|
|
|
|
|
# Argument name -> schema type string for the scattered tensor options fields.
# dispatch_lambda_exprs compares each tensor options argument's name and
# str(type) against this table.
TENSOR_OPTIONS_FIELDS: dict[str, str] = {
    "dtype": "ScalarType?",
    "device": "Device?",
    "layout": "Layout?",
    "pin_memory": "bool?",
    "requires_grad": "bool?",
}
|
|
|
|
|
|
|
|
|
|
|
|
def dispatch_lambda_exprs(
    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> DispatchLambdaArgumentExprs:
    """Bind arg parser outputs (Python args) to dispatch lambda arguments (C++ args).

    For every lambda argument returned by ``dispatch_lambda_args`` this
    produces the C++ expression to pass for it, together with ``inits``: C++
    statements that declare any temporaries those expressions refer to.

    Args:
        ps: Python signature whose arguments are read from the arg parser.
        f: Native function being bound; consulted for tensor options and for
            whether it is an out function.
        symint: Forwarded to ``arg_parser_output_exprs`` and
            ``dispatch_lambda_args``.

    Returns:
        ``DispatchLambdaArgumentExprs`` whose ``exprs`` follow the order of
        the lambda arguments and whose ``inits`` are the statements collected
        while binding.

    Raises:
        RuntimeError: If tensor options are combined with an out function, a
            tensor options field has an unrecognized name or type, the set of
            tensor options fields is incomplete, or scattered tensor options
            args are inconsistent (``dtype`` without an out function, missing
            ``layout``/``device``, or missing ``requires_grad``).
    """
    arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
    lambda_args = dispatch_lambda_args(ps, f, symint=symint)
    inits: list[str] = []
    lambda_args_exprs: dict[str, str] = {}

    has_toptions = has_tensor_options(f)

    # 1. Special inits/unpacking that provide binding exprs for lambda arguments.
    for a in ps.arguments(skip_tensor_options=True):
        name = a.name
        arg_parser_expr = arg_parser_outputs[name].expr

        if has_toptions and name == "self":
            # `self` is materialized into a local and then passed by name.
            inits.append(f"auto self = {arg_parser_expr};")
            lambda_args_exprs[name] = name
        elif (
            isinstance(a, PythonOutArgument)
            and len(a.outputs) > 1
            and f.func.is_out_fn()
        ):
            # A packed multi-output argument is unpacked once into `out`, and
            # each individual output is bound to an element of it.
            inits.append(f"auto out = {arg_parser_expr};")
            for i, out_arg in enumerate(a.outputs):
                lambda_args_exprs[out_arg.name] = f"out[{i}]"
        elif str(a.type) == "Dimname[]?":
            # The parsed optional is stored in a temporary and re-wrapped as
            # ::std::optional<DimnameList>.
            # NOTE(review): presumably the parser's optional result does not
            # convert implicitly to the lambda's parameter type - confirm.
            inits.extend(
                [
                    f"auto __{name} = {arg_parser_expr};",
                    f"::std::optional<DimnameList> {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;",
                ]
            )
            lambda_args_exprs[name] = name
        else:
            # Default case: use the arg parser output expression directly.
            lambda_args_exprs[name] = arg_parser_expr

    # A method's self is bound by name rather than to a parser expression.
    if ps.method:
        lambda_args_exprs["self"] = "self"

    # 2. Special packing/checking for TensorOptions.
    tensor_options_args_names = [a.name for a in ps.tensor_options_args]
    if has_toptions:
        if f.func.is_out_fn():
            raise RuntimeError(f"{f.func}: tensor options with output arg")
        for a in ps.tensor_options_args:
            if a.name not in TENSOR_OPTIONS_FIELDS:
                raise RuntimeError(
                    f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments"
                )
            if str(a.type) != TENSOR_OPTIONS_FIELDS[a.name]:
                raise RuntimeError(
                    f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
                )
        if not all(
            field in tensor_options_args_names for field in TENSOR_OPTIONS_FIELDS
        ):
            raise RuntimeError(
                f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"
            )

        # NOTE(review): whitespace inside the emitted C++ snippets below only
        # affects the layout of the generated code - confirm it matches the
        # expected generated output.
        inits.append(
            f"""\
const auto options = TensorOptions()
.dtype({arg_parser_outputs["dtype"].expr})
.device({arg_parser_outputs["device"].expr})
.layout({arg_parser_outputs["layout"].expr})
.requires_grad({arg_parser_outputs["requires_grad"].expr})
.pinned_memory({arg_parser_outputs["pin_memory"].expr});
torch::utils::maybe_initialize_device(options);
"""
        )
        lambda_args_exprs["options"] = "options"

    # 3. Special case: scattered TensorOptions fields accessed without packing.
    if not has_toptions and tensor_options_args_names:
        if "dtype" in tensor_options_args_names:
            # Only valid for an out function: the fields are checked against
            # the output tensor.
            if not f.func.is_out_fn():
                # `ps.arguments` is a method; call it so the message shows the
                # arguments rather than a bound-method repr.
                raise RuntimeError(
                    f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments()}"
                )
            if not all(
                field in tensor_options_args_names for field in ("layout", "device")
            ):
                raise RuntimeError(
                    f"{f.func}: incomplete tensor options for output check"
                )

            inits.append(
                f"""\
check_out_type_matches({arg_parser_outputs["out"].expr}, {arg_parser_outputs["dtype"].expr},
{arg_parser_outputs["dtype"].is_none_expr}, {arg_parser_outputs["layout"].expr},
{arg_parser_outputs["device"].expr}, {arg_parser_outputs["device"].is_none_expr});
"""
            )
        # requires_grad must be among the scattered fields.
        if "requires_grad" not in tensor_options_args_names:
            raise RuntimeError(
                f'{f.func}: expected "requires_grad" in tensor_options_args absent, but found [{tensor_options_args_names}]'
            )

    return DispatchLambdaArgumentExprs(
        exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args),
        inits=inits,
    )
|
|
|
|