mirror of https://github.com/Nonannet/copapy.git
Net and value types separated
This commit is contained in:
parent
df22438ffe
commit
6dcaa6797c
|
|
@ -35,23 +35,23 @@ def grad(x: Any, y: value[Any] | Sequence[value[Any]] | vector[Any] | matrix[Any
|
||||||
assert isinstance(y, Sequence) or isinstance(y, vector)
|
assert isinstance(y, Sequence) or isinstance(y, vector)
|
||||||
y_set = {v for v in y}
|
y_set = {v for v in y}
|
||||||
|
|
||||||
edges = cpb.get_all_dag_edges_between([x.source], (net.source for net in y_set if isinstance(net, Net)))
|
edges = cpb.get_all_dag_edges_between([x.net.source], (v.net.source for v in y_set if isinstance(v, value)))
|
||||||
ordered_ops = cpb.stable_toposort(edges)
|
ordered_ops = cpb.stable_toposort(edges)
|
||||||
|
|
||||||
net_lookup = {net.source: net for node in ordered_ops for net in node.args}
|
net_lookup = {net.source: net for node in ordered_ops for net in node.args}
|
||||||
grad_dict: dict[Net, unifloat] = dict()
|
grad_dict: dict[Net, unifloat] = dict()
|
||||||
|
|
||||||
def add_grad(val: value[Any], gradient_value: unifloat) -> None:
|
def add_grad(val: value[Any], gradient_value: unifloat) -> None:
|
||||||
grad_dict[val] = grad_dict.get(val, 0.0) + gradient_value
|
grad_dict[val.net] = grad_dict.get(val.net, 0.0) + gradient_value
|
||||||
|
|
||||||
for node in reversed(ordered_ops):
|
for node in reversed(ordered_ops):
|
||||||
#print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}")
|
#print(f"--> {'x' if node in net_lookup else ' '}", node, f"{net_lookup.get(node)}")
|
||||||
if node.args:
|
if node.args:
|
||||||
args: Sequence[Any] = list(node.args)
|
args: Sequence[Net] = list(node.args)
|
||||||
g = 1.0 if node is x.source else grad_dict[net_lookup[node]]
|
g = 1.0 if node is x.net.source else grad_dict[net_lookup[node]]
|
||||||
opn = node.name.split('_')[0]
|
opn = node.name.split('_')[0]
|
||||||
a: value[Any] = args[0]
|
a: value[float] = value(args[0])
|
||||||
b: value[Any] = args[1] if len(args) > 1 else a
|
b: value[float] = value(args[1]) if len(args) > 1 else a
|
||||||
|
|
||||||
if opn in ['ge', 'gt', 'eq', 'ne', 'floordiv', 'bwand', 'bwor', 'bwxor']:
|
if opn in ['ge', 'gt', 'eq', 'ne', 'floordiv', 'bwand', 'bwor', 'bwxor']:
|
||||||
pass # Derivative is 0 for all ops returning integers
|
pass # Derivative is 0 for all ops returning integers
|
||||||
|
|
@ -119,9 +119,9 @@ def grad(x: Any, y: value[Any] | Sequence[value[Any]] | vector[Any] | matrix[Any
|
||||||
raise ValueError(f"Operation {opn} not yet supported for auto diff.")
|
raise ValueError(f"Operation {opn} not yet supported for auto diff.")
|
||||||
|
|
||||||
if isinstance(y, value):
|
if isinstance(y, value):
|
||||||
return grad_dict[y]
|
return grad_dict[y.net]
|
||||||
if isinstance(y, vector):
|
if isinstance(y, vector):
|
||||||
return vector(grad_dict[yi] if isinstance(yi, value) else 0.0 for yi in y)
|
return vector(grad_dict[yi.net] if isinstance(yi, value) else 0.0 for yi in y)
|
||||||
if isinstance(y, matrix):
|
if isinstance(y, matrix):
|
||||||
return matrix((grad_dict[yi] if isinstance(yi, value) else 0.0 for yi in row) for row in y)
|
return matrix((grad_dict[yi.net] if isinstance(yi, value) else 0.0 for yi in row) for row in y)
|
||||||
return [grad_dict[yi] for yi in y]
|
return [grad_dict[yi.net] for yi in y]
|
||||||
|
|
|
||||||
|
|
@ -84,34 +84,50 @@ class Net:
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
return self.source.node_hash
|
return self.source.node_hash
|
||||||
|
|
||||||
|
def __eq__(self, value: object) -> bool:
|
||||||
|
return isinstance(value, Net) and self.source.node_hash == value.source.node_hash # TODO: change to 64 bit hash
|
||||||
|
|
||||||
class value(Generic[TNum], Net):
|
|
||||||
|
class value(Generic[TNum]):
|
||||||
"""A "value" represents a typed scalar variable. It supports arithmetic and
|
"""A "value" represents a typed scalar variable. It supports arithmetic and
|
||||||
comparison operations.
|
comparison operations.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
dtype (str): Data type of this value.
|
dtype (str): Data type of this value.
|
||||||
"""
|
"""
|
||||||
def __init__(self, source: TNum | Node, dtype: str | None = None):
|
def __init__(self, source: TNum | Net, dtype: str | None = None):
|
||||||
"""Instance a value.
|
"""Instance a value.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
source: A numeric value or Node object.
|
dtype: Data type of this value.
|
||||||
dtype: Data type of this value. Required if source is a Node.
|
net: Reference to the underlying Net in the graph
|
||||||
"""
|
"""
|
||||||
if isinstance(source, Node):
|
if isinstance(source, Net):
|
||||||
self.source = source
|
self.net: Net = source
|
||||||
assert dtype, 'For source type Node a dtype argument is required.'
|
if dtype:
|
||||||
|
assert transl_type(dtype) == source.dtype, f"Type of Net ({source.dtype}) does not match {dtype}"
|
||||||
|
self.dtype: str = dtype
|
||||||
|
else:
|
||||||
|
self.dtype = source.dtype
|
||||||
|
elif dtype == 'int' or dtype == 'bool':
|
||||||
|
new_node = CPConstant(int(source), False)
|
||||||
|
self.net = Net(new_node.dtype, new_node)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
elif isinstance(source, float):
|
elif dtype == 'float':
|
||||||
self.source = CPConstant(source, False)
|
new_node = CPConstant(float(source), False)
|
||||||
self.dtype = 'float'
|
self.net = Net(new_node.dtype, new_node)
|
||||||
elif isinstance(source, bool):
|
self.dtype = dtype
|
||||||
self.source = CPConstant(source, False)
|
elif dtype is None:
|
||||||
|
if isinstance(source, bool):
|
||||||
|
new_node = CPConstant(source, False)
|
||||||
|
self.net = Net(new_node.dtype, new_node)
|
||||||
self.dtype = 'bool'
|
self.dtype = 'bool'
|
||||||
else:
|
else:
|
||||||
self.source = CPConstant(source, False)
|
new_node = CPConstant(source, False)
|
||||||
self.dtype = 'int'
|
self.net = Net(new_node.dtype, new_node)
|
||||||
|
self.dtype = new_node.dtype
|
||||||
|
else:
|
||||||
|
raise ValueError('Unknown type: {dtype}')
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __add__(self: 'value[TNum]', other: 'value[TNum] | TNum') -> 'value[TNum]': ...
|
def __add__(self: 'value[TNum]', other: 'value[TNum] | TNum') -> 'value[TNum]': ...
|
||||||
|
|
@ -227,28 +243,22 @@ class value(Generic[TNum], Net):
|
||||||
return cast(TCPNum, add_op('sub', [value(0), self]))
|
return cast(TCPNum, add_op('sub', [value(0), self]))
|
||||||
|
|
||||||
def __gt__(self, other: TVarNumb) -> 'value[int]':
|
def __gt__(self, other: TVarNumb) -> 'value[int]':
|
||||||
ret = add_op('gt', [self, other])
|
return add_op('gt', [self, other], dtype='bool')
|
||||||
return value(ret.source, dtype='bool')
|
|
||||||
|
|
||||||
def __lt__(self, other: TVarNumb) -> 'value[int]':
|
def __lt__(self, other: TVarNumb) -> 'value[int]':
|
||||||
ret = add_op('gt', [other, self])
|
return add_op('gt', [other, self], dtype='bool')
|
||||||
return value(ret.source, dtype='bool')
|
|
||||||
|
|
||||||
def __ge__(self, other: TVarNumb) -> 'value[int]':
|
def __ge__(self, other: TVarNumb) -> 'value[int]':
|
||||||
ret = add_op('ge', [self, other])
|
return add_op('ge', [self, other], dtype='bool')
|
||||||
return value(ret.source, dtype='bool')
|
|
||||||
|
|
||||||
def __le__(self, other: TVarNumb) -> 'value[int]':
|
def __le__(self, other: TVarNumb) -> 'value[int]':
|
||||||
ret = add_op('ge', [other, self])
|
return add_op('ge', [other, self], dtype='bool')
|
||||||
return value(ret.source, dtype='bool')
|
|
||||||
|
|
||||||
def __eq__(self, other: TVarNumb) -> 'value[int]': # type: ignore
|
def __eq__(self, other: TVarNumb) -> 'value[int]': # type: ignore
|
||||||
ret = add_op('eq', [self, other], True)
|
return add_op('eq', [self, other], True, dtype='bool')
|
||||||
return value(ret.source, dtype='bool')
|
|
||||||
|
|
||||||
def __ne__(self, other: TVarNumb) -> 'value[int]': # type: ignore
|
def __ne__(self, other: TVarNumb) -> 'value[int]': # type: ignore
|
||||||
ret = add_op('ne', [self, other], True)
|
return add_op('ne', [self, other], True, dtype='bool')
|
||||||
return value(ret.source, dtype='bool')
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __mod__(self: 'value[TNum]', other: 'value[TNum] | TNum') -> 'value[TNum]': ...
|
def __mod__(self: 'value[TNum]', other: 'value[TNum] | TNum') -> 'value[TNum]': ...
|
||||||
|
|
@ -295,7 +305,7 @@ class value(Generic[TNum], Net):
|
||||||
return cp.pow(other, self)
|
return cp.pow(other, self)
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
return super().__hash__()
|
return id(self)
|
||||||
|
|
||||||
# Bitwise and shift operations for cp[int]
|
# Bitwise and shift operations for cp[int]
|
||||||
def __lshift__(self, other: uniint) -> 'value[int]':
|
def __lshift__(self, other: uniint) -> 'value[int]':
|
||||||
|
|
@ -330,16 +340,26 @@ class value(Generic[TNum], Net):
|
||||||
|
|
||||||
|
|
||||||
class CPConstant(Node):
|
class CPConstant(Node):
|
||||||
def __init__(self, value: int | float, anonymous: bool = True):
|
def __init__(self, value: Any, anonymous: bool = True):
|
||||||
self.dtype, self.value = _get_data_and_dtype(value)
|
if isinstance(value, int):
|
||||||
|
self.value: int | float = value
|
||||||
|
self.dtype = 'int'
|
||||||
|
elif isinstance(value, float):
|
||||||
|
self.value = value
|
||||||
|
self.dtype = 'float'
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Non supported data type: {type(value).__name__}')
|
||||||
|
|
||||||
self.name = 'const_' + self.dtype
|
self.name = 'const_' + self.dtype
|
||||||
self.args = tuple()
|
self.args = tuple()
|
||||||
self.node_hash = hash(value) ^ hash(self.dtype) if anonymous else id(self)
|
self.node_hash = hash(value) ^ hash(self.dtype) if anonymous else id(self)
|
||||||
|
|
||||||
|
|
||||||
class Write(Node):
|
class Write(Node):
|
||||||
def __init__(self, input: Net | int | float):
|
def __init__(self, input: value[Any] | Net | int | float):
|
||||||
if isinstance(input, Net):
|
if isinstance(input, value):
|
||||||
|
net = input.net
|
||||||
|
elif isinstance(input, Net):
|
||||||
net = input
|
net = input
|
||||||
else:
|
else:
|
||||||
node = CPConstant(input)
|
node = CPConstant(input)
|
||||||
|
|
@ -352,15 +372,16 @@ class Write(Node):
|
||||||
|
|
||||||
class Op(Node):
|
class Op(Node):
|
||||||
def __init__(self, typed_op_name: str, args: Sequence[Net], commutative: bool = False):
|
def __init__(self, typed_op_name: str, args: Sequence[Net], commutative: bool = False):
|
||||||
assert not args or any(isinstance(t, Net) for t in args), 'args parameter must be of type list[Net]'
|
|
||||||
self.name: str = typed_op_name
|
self.name: str = typed_op_name
|
||||||
self.args: tuple[Net, ...] = tuple(args)
|
self.args: tuple[Net, ...] = tuple(args)
|
||||||
self.node_hash = self.get_node_hash(commutative)
|
self.node_hash = self.get_node_hash(commutative)
|
||||||
|
|
||||||
|
|
||||||
def net_from_value(val: Any) -> value[Any]:
|
def value_from_number(val: Any) -> value[Any]:
|
||||||
vi = CPConstant(val)
|
# Create anonymous constant that can be removed during optimization
|
||||||
return value(vi, vi.dtype)
|
new_node = CPConstant(val)
|
||||||
|
new_net = Net(new_node.dtype, new_node)
|
||||||
|
return value(new_net)
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -392,29 +413,22 @@ def iif(expression: Any, true_result: Any, false_result: Any) -> Any:
|
||||||
return (expression != 0) * true_result + (expression == 0) * false_result
|
return (expression != 0) * true_result + (expression == 0) * false_result
|
||||||
|
|
||||||
|
|
||||||
def add_op(op: str, args: list[value[Any] | int | float], commutative: bool = False) -> value[Any]:
|
def add_op(op: str, args: list[value[Any] | int | float], commutative: bool = False, dtype: str | None = None) -> value[Any]:
|
||||||
arg_nets = [a if isinstance(a, Net) else net_from_value(a) for a in args]
|
arg_values = [a if isinstance(a, value) else value_from_number(a) for a in args]
|
||||||
|
|
||||||
if commutative:
|
if commutative:
|
||||||
arg_nets = sorted(arg_nets, key=lambda a: a.dtype) # TODO: update the stencil generator to generate only sorted order
|
arg_values = sorted(arg_values, key=lambda a: a.dtype) # TODO: update the stencil generator to generate only sorted order
|
||||||
|
|
||||||
typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_nets])
|
typed_op = '_'.join([op] + [transl_type(a.dtype) for a in arg_values])
|
||||||
|
|
||||||
if typed_op not in generic_sdb.stencil_definitions:
|
if typed_op not in generic_sdb.stencil_definitions:
|
||||||
raise NotImplementedError(f"Operation {op} not implemented for {' and '.join([a.dtype for a in arg_nets])}")
|
raise NotImplementedError(f"Operation {op} not implemented for {' and '.join([a.dtype for a in arg_values])}")
|
||||||
|
|
||||||
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
|
result_type = generic_sdb.stencil_definitions[typed_op].split('_')[0]
|
||||||
|
|
||||||
if result_type == 'float':
|
result_net = Net(result_type, Op(typed_op, [av.net for av in arg_values], commutative))
|
||||||
return value[float](Op(typed_op, arg_nets, commutative), result_type)
|
|
||||||
else:
|
|
||||||
return value[int](Op(typed_op, arg_nets, commutative), result_type)
|
|
||||||
|
|
||||||
|
if dtype:
|
||||||
|
result_type = dtype
|
||||||
|
|
||||||
def _get_data_and_dtype(value: Any) -> tuple[str, float | int]:
|
return value(result_net, result_type)
|
||||||
if isinstance(value, int):
|
|
||||||
return ('int', int(value))
|
|
||||||
elif isinstance(value, float):
|
|
||||||
return ('float', float(value))
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Non supported data type: {type(value).__name__}')
|
|
||||||
|
|
|
||||||
|
|
@ -221,6 +221,8 @@ def get_nets(*inputs: Iterable[Iterable[Any]]) -> list[Net]:
|
||||||
for net in el:
|
for net in el:
|
||||||
if isinstance(net, Net):
|
if isinstance(net, Net):
|
||||||
nets.add(net)
|
nets.add(net)
|
||||||
|
else:
|
||||||
|
assert net is None or isinstance(net, Node), net
|
||||||
|
|
||||||
return list(nets)
|
return list(nets)
|
||||||
|
|
||||||
|
|
@ -351,7 +353,7 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
|
||||||
dw.write_com(binw.Command.FREE_MEMORY)
|
dw.write_com(binw.Command.FREE_MEMORY)
|
||||||
|
|
||||||
# Get all nets/variables associated with heap memory
|
# Get all nets/variables associated with heap memory
|
||||||
variable_list = get_nets([[const_net_list]], extended_output_ops)
|
variable_list = get_nets([const_net_list], extended_output_ops)
|
||||||
|
|
||||||
stencil_names = {node.name for _, node in extended_output_ops}
|
stencil_names = {node.name for _, node in extended_output_ops}
|
||||||
aux_function_names = sdb.get_sub_functions(stencil_names)
|
aux_function_names = sdb.get_sub_functions(stencil_names)
|
||||||
|
|
@ -378,7 +380,7 @@ def compile_to_dag(node_list: Iterable[Node], sdb: stencil_database) -> tuple[bi
|
||||||
dw.write_int(start)
|
dw.write_int(start)
|
||||||
dw.write_int(lengths)
|
dw.write_int(lengths)
|
||||||
dw.write_value(net.source.value, lengths)
|
dw.write_value(net.source.value, lengths)
|
||||||
#print(f'+ {net.dtype} {net.source.value}')
|
print(f'+ {net.dtype} {net.source.value}')
|
||||||
|
|
||||||
# prep auxiliary_functions
|
# prep auxiliary_functions
|
||||||
code_section_layout, func_addr_lookup, aux_func_len = get_aux_func_layout(aux_function_names, sdb)
|
code_section_layout, func_addr_lookup, aux_func_len = get_aux_func_layout(aux_function_names, sdb)
|
||||||
|
|
|
||||||
|
|
@ -66,21 +66,21 @@ class Target():
|
||||||
def __del__(self) -> None:
|
def __del__(self) -> None:
|
||||||
clear_target(self._context)
|
clear_target(self._context)
|
||||||
|
|
||||||
def compile(self, *values: int | float | value[int] | value[float] | Iterable[int | float | value[int] | value[float]]) -> None:
|
def compile(self, *values: int | float | value[Any] | Iterable[int | float | value[Any]]) -> None:
|
||||||
"""Compiles the code to compute the given values.
|
"""Compiles the code to compute the given values.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
values: Values to compute
|
values: Values to compute
|
||||||
"""
|
"""
|
||||||
nodes: list[Node] = []
|
nodes: list[Node] = []
|
||||||
for s in values:
|
for input in values:
|
||||||
if isinstance(s, Iterable):
|
if isinstance(input, Iterable):
|
||||||
for net in s:
|
for v in input:
|
||||||
if isinstance(net, Net):
|
if isinstance(v, value):
|
||||||
nodes.append(Write(net))
|
nodes.append(Write(v))
|
||||||
else:
|
else:
|
||||||
if isinstance(s, Net):
|
if isinstance(input, value):
|
||||||
nodes.append(Write(s))
|
nodes.append(Write(input))
|
||||||
|
|
||||||
dw, self._values = compile_to_dag(nodes, self.sdb)
|
dw, self._values = compile_to_dag(nodes, self.sdb)
|
||||||
dw.write_com(binw.Command.END_COM)
|
dw.write_com(binw.Command.END_COM)
|
||||||
|
|
@ -95,32 +95,33 @@ class Target():
|
||||||
assert coparun(self._context, dw.get_data()) > 0
|
assert coparun(self._context, dw.get_data()) > 0
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def read_value(self, net: value[T]) -> T: ...
|
def read_value(self, variables: value[T]) -> T: ...
|
||||||
@overload
|
@overload
|
||||||
def read_value(self, net: NumLike) -> float | int | bool: ...
|
def read_value(self, variables: NumLike) -> float | int | bool: ...
|
||||||
@overload
|
@overload
|
||||||
def read_value(self, net: Iterable[T | value[T]]) -> list[T]: ...
|
def read_value(self, variables: Iterable[T | value[T]]) -> list[T]: ...
|
||||||
def read_value(self, net: NumLike | value[T] | Iterable[T | value[T]]) -> Any:
|
def read_value(self, variables: NumLike | value[T] | Iterable[T | value[T]]) -> Any:
|
||||||
"""Reads the numeric value of a copapy type.
|
"""Reads the numeric value of a copapy type.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
net: Value or multiple Values to read
|
variables: Variable or multiple variables to read
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Numeric value or values
|
Numeric value or values
|
||||||
"""
|
"""
|
||||||
if isinstance(net, Iterable):
|
if isinstance(variables, Iterable):
|
||||||
return [self.read_value(ni) if isinstance(ni, value) else ni for ni in net]
|
return [self.read_value(ni) if isinstance(ni, value) else ni for ni in variables]
|
||||||
|
|
||||||
if isinstance(net, float | int):
|
if isinstance(variables, float | int):
|
||||||
return net
|
return variables
|
||||||
|
|
||||||
assert isinstance(net, Net), "Argument must be a copapy value"
|
assert isinstance(variables, value), "Argument must be a copapy value"
|
||||||
assert net in self._values, f"Value {net} not found. It might not have been compiled for the target."
|
assert variables.net in self._values, f"Value {variables} not found. It might not have been compiled for the target."
|
||||||
addr, lengths, var_type = self._values[net]
|
addr, lengths, _ = self._values[variables.net]
|
||||||
|
var_type = variables.dtype
|
||||||
assert lengths > 0
|
assert lengths > 0
|
||||||
data = read_data_mem(self._context, addr, lengths)
|
data = read_data_mem(self._context, addr, lengths)
|
||||||
assert data is not None and len(data) == lengths, f"Failed to read value {net}"
|
assert data is not None and len(data) == lengths, f"Failed to read value {variables}"
|
||||||
en = {'little': '<', 'big': '>'}[self.sdb.byteorder]
|
en = {'little': '<', 'big': '>'}[self.sdb.byteorder]
|
||||||
if var_type == 'float':
|
if var_type == 'float':
|
||||||
if lengths == 4:
|
if lengths == 4:
|
||||||
|
|
@ -142,24 +143,24 @@ class Target():
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported value type: {var_type}")
|
raise ValueError(f"Unsupported value type: {var_type}")
|
||||||
|
|
||||||
def write_value(self, net: value[Any] | Iterable[value[Any]], value: int | float | Iterable[int | float]) -> None:
|
def write_value(self, variables: value[Any] | Iterable[value[Any]], data: int | float | Iterable[int | float]) -> None:
|
||||||
"""Write to a copapy value on the target.
|
"""Write to a copapy value on the target.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
net: Singe variable or multiple variables to overwrite
|
variables: Singe variable or multiple variables to overwrite
|
||||||
value: Singe value or multiple values to write
|
value: Singe value or multiple values to write
|
||||||
"""
|
"""
|
||||||
if isinstance(net, Iterable):
|
if isinstance(variables, Iterable):
|
||||||
assert isinstance(value, Iterable), "If net is iterable, value must be iterable too"
|
assert isinstance(data, Iterable), "If net is iterable, value must be iterable too"
|
||||||
for ni, vi in zip(net, value):
|
for ni, vi in zip(variables, data):
|
||||||
self.write_value(ni, vi)
|
self.write_value(ni, vi)
|
||||||
return
|
return
|
||||||
|
|
||||||
assert not isinstance(value, Iterable), "If net is not iterable, value must not be iterable"
|
assert not isinstance(data, Iterable), "If net is not iterable, value must not be iterable"
|
||||||
|
|
||||||
assert isinstance(net, Net), "Argument must be a copapy value"
|
assert isinstance(variables, value), "Argument must be a copapy value"
|
||||||
assert net in self._values, f"Value {net} not found. It might not have been compiled for the target."
|
assert variables.net in self._values, f"Value {variables} not found. It might not have been compiled for the target."
|
||||||
addr, lengths, var_type = self._values[net]
|
addr, lengths, var_type = self._values[variables.net]
|
||||||
assert lengths > 0
|
assert lengths > 0
|
||||||
|
|
||||||
dw = binw.data_writer(self.sdb.byteorder)
|
dw = binw.data_writer(self.sdb.byteorder)
|
||||||
|
|
@ -168,17 +169,17 @@ class Target():
|
||||||
dw.write_int(lengths)
|
dw.write_int(lengths)
|
||||||
|
|
||||||
if var_type == 'float':
|
if var_type == 'float':
|
||||||
dw.write_value(float(value), lengths)
|
dw.write_value(float(data), lengths)
|
||||||
elif var_type == 'int' or var_type == 'bool':
|
elif var_type == 'int' or var_type == 'bool':
|
||||||
dw.write_value(int(value), lengths)
|
dw.write_value(int(data), lengths)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported value type: {var_type}")
|
raise ValueError(f"Unsupported value type: {var_type}")
|
||||||
|
|
||||||
dw.write_com(binw.Command.END_COM)
|
dw.write_com(binw.Command.END_COM)
|
||||||
assert coparun(self._context, dw.get_data()) > 0
|
assert coparun(self._context, dw.get_data()) > 0
|
||||||
|
|
||||||
def read_value_remote(self, net: Net) -> None:
|
def read_value_remote(self, variable: value[Any]) -> None:
|
||||||
"""Reads the raw data of a value by the runner."""
|
"""Reads the raw data of a value by the runner."""
|
||||||
dw = binw.data_writer(self.sdb.byteorder)
|
dw = binw.data_writer(self.sdb.byteorder)
|
||||||
add_read_command(dw, self._values, net)
|
add_read_command(dw, self._values, variable.net)
|
||||||
assert coparun(self._context, dw.get_data()) > 0
|
assert coparun(self._context, dw.get_data()) > 0
|
||||||
|
|
|
||||||
|
|
@ -30,9 +30,9 @@ def test_compile():
|
||||||
il.write_com(_binwrite.Command.RUN_PROG)
|
il.write_com(_binwrite.Command.RUN_PROG)
|
||||||
#il.write_com(_binwrite.Command.DUMP_CODE)
|
#il.write_com(_binwrite.Command.DUMP_CODE)
|
||||||
|
|
||||||
for net in ret_test:
|
for v in ret_test:
|
||||||
assert isinstance(net, copapy.backend.Net)
|
assert isinstance(v, value)
|
||||||
add_read_command(il, variables, net)
|
add_read_command(il, variables, v.net)
|
||||||
|
|
||||||
il.write_com(_binwrite.Command.END_COM)
|
il.write_com(_binwrite.Command.END_COM)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,7 @@ def test_timing_compiler():
|
||||||
|
|
||||||
|
|
||||||
# Get all nets/variables associated with heap memory
|
# Get all nets/variables associated with heap memory
|
||||||
variable_list = get_nets([[const_net_list]], extended_output_ops)
|
variable_list = get_nets([const_net_list], extended_output_ops)
|
||||||
stencil_names = {node.name for _, node in extended_output_ops}
|
stencil_names = {node.name for _, node in extended_output_ops}
|
||||||
|
|
||||||
print(f'-- get_sub_functions: {len(stencil_names)}')
|
print(f'-- get_sub_functions: {len(stencil_names)}')
|
||||||
|
|
|
||||||
|
|
@ -65,9 +65,9 @@ def test_compile():
|
||||||
# run program command
|
# run program command
|
||||||
il.write_com(_binwrite.Command.RUN_PROG)
|
il.write_com(_binwrite.Command.RUN_PROG)
|
||||||
|
|
||||||
for net in ret:
|
for v in ret:
|
||||||
assert isinstance(net, copapy.backend.Net)
|
assert isinstance(v, cp.value)
|
||||||
add_read_command(il, variables, net)
|
add_read_command(il, variables, v.net)
|
||||||
|
|
||||||
il.write_com(_binwrite.Command.END_COM)
|
il.write_com(_binwrite.Command.END_COM)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -60,9 +60,9 @@ def test_compile():
|
||||||
# run program command
|
# run program command
|
||||||
il.write_com(_binwrite.Command.RUN_PROG)
|
il.write_com(_binwrite.Command.RUN_PROG)
|
||||||
|
|
||||||
for net in ret:
|
for v in ret:
|
||||||
assert isinstance(net, backend.Net)
|
assert isinstance(v, cp.value)
|
||||||
add_read_command(il, variables, net)
|
add_read_command(il, variables, v.net)
|
||||||
|
|
||||||
il.write_com(_binwrite.Command.END_COM)
|
il.write_com(_binwrite.Command.END_COM)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -61,9 +61,9 @@ def test_compile():
|
||||||
il.write_com(_binwrite.Command.RUN_PROG)
|
il.write_com(_binwrite.Command.RUN_PROG)
|
||||||
#il.write_com(_binwrite.Command.DUMP_CODE)
|
#il.write_com(_binwrite.Command.DUMP_CODE)
|
||||||
|
|
||||||
for net in ret:
|
for v in ret:
|
||||||
assert isinstance(net, backend.Net)
|
assert isinstance(v, cp.value)
|
||||||
add_read_command(il, variables, net)
|
add_read_command(il, variables, v.net)
|
||||||
|
|
||||||
il.write_com(_binwrite.Command.END_COM)
|
il.write_com(_binwrite.Command.END_COM)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,9 +33,9 @@ def test_compile():
|
||||||
# run program command
|
# run program command
|
||||||
il.write_com(_binwrite.Command.RUN_PROG)
|
il.write_com(_binwrite.Command.RUN_PROG)
|
||||||
|
|
||||||
for net in ret:
|
for v in ret:
|
||||||
assert isinstance(net, Net)
|
assert isinstance(v, value)
|
||||||
add_read_command(il, vars, net)
|
add_read_command(il, vars, v.net)
|
||||||
|
|
||||||
il.write_com(_binwrite.Command.END_COM)
|
il.write_com(_binwrite.Command.END_COM)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,9 +28,9 @@ def test_compile_sqrt():
|
||||||
# run program command
|
# run program command
|
||||||
il.write_com(_binwrite.Command.RUN_PROG)
|
il.write_com(_binwrite.Command.RUN_PROG)
|
||||||
|
|
||||||
for net in ret:
|
for v in ret:
|
||||||
assert isinstance(net, copapy.backend.Net)
|
assert isinstance(v, value)
|
||||||
add_read_command(il, variables, net)
|
add_read_command(il, variables, v.net)
|
||||||
|
|
||||||
il.write_com(_binwrite.Command.END_COM)
|
il.write_com(_binwrite.Command.END_COM)
|
||||||
|
|
||||||
|
|
@ -62,9 +62,9 @@ def test_compile_log():
|
||||||
# run program command
|
# run program command
|
||||||
il.write_com(_binwrite.Command.RUN_PROG)
|
il.write_com(_binwrite.Command.RUN_PROG)
|
||||||
|
|
||||||
for net in ret:
|
for v in ret:
|
||||||
assert isinstance(net, copapy.backend.Net)
|
assert isinstance(v, value)
|
||||||
add_read_command(il, variables, net)
|
add_read_command(il, variables, v.net)
|
||||||
|
|
||||||
il.write_com(_binwrite.Command.END_COM)
|
il.write_com(_binwrite.Command.END_COM)
|
||||||
|
|
||||||
|
|
@ -96,9 +96,9 @@ def test_compile_sin():
|
||||||
# run program command
|
# run program command
|
||||||
il.write_com(_binwrite.Command.RUN_PROG)
|
il.write_com(_binwrite.Command.RUN_PROG)
|
||||||
|
|
||||||
for net in ret:
|
for v in ret:
|
||||||
assert isinstance(net, copapy.backend.Net)
|
assert isinstance(v, copapy.value)
|
||||||
add_read_command(il, variables, net)
|
add_read_command(il, variables, v.net)
|
||||||
|
|
||||||
il.write_com(_binwrite.Command.END_COM)
|
il.write_com(_binwrite.Command.END_COM)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ def test_get_dag_stats():
|
||||||
v3 = sum((v1 + i + 7) @ v2 for i in range(sum_size))
|
v3 = sum((v1 + i + 7) @ v2 for i in range(sum_size))
|
||||||
|
|
||||||
assert isinstance(v3, value)
|
assert isinstance(v3, value)
|
||||||
stat = get_dag_stats([v3])
|
stat = get_dag_stats([v3.net])
|
||||||
print(stat)
|
print(stat)
|
||||||
|
|
||||||
assert stat['const_float'] == 2 * v_size
|
assert stat['const_float'] == 2 * v_size
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ def test_multi_target():
|
||||||
tg1.compile(e)
|
tg1.compile(e)
|
||||||
|
|
||||||
# Patch constant value
|
# Patch constant value
|
||||||
a.source = cp._basic_types.CPConstant(1000.0)
|
a.net.source = cp._basic_types.CPConstant(1000.0)
|
||||||
|
|
||||||
tg2 = cp.Target()
|
tg2 = cp.Target()
|
||||||
tg2.compile(e)
|
tg2.compile(e)
|
||||||
|
|
|
||||||
|
|
@ -107,9 +107,9 @@ def test_compile():
|
||||||
dw.write_com(_binwrite.Command.RUN_PROG)
|
dw.write_com(_binwrite.Command.RUN_PROG)
|
||||||
#dw.write_com(_binwrite.Command.DUMP_CODE)
|
#dw.write_com(_binwrite.Command.DUMP_CODE)
|
||||||
|
|
||||||
for net in ret_test:
|
for v in ret_test:
|
||||||
assert isinstance(net, backend.Net)
|
assert isinstance(v, value)
|
||||||
add_read_command(dw, variables, net)
|
add_read_command(dw, variables, v.net)
|
||||||
|
|
||||||
#dw.write_com(_binwrite.Command.READ_DATA)
|
#dw.write_com(_binwrite.Command.READ_DATA)
|
||||||
#dw.write_int(0)
|
#dw.write_int(0)
|
||||||
|
|
|
||||||
|
|
@ -109,9 +109,9 @@ def test_compile():
|
||||||
dw.write_com(_binwrite.Command.RUN_PROG)
|
dw.write_com(_binwrite.Command.RUN_PROG)
|
||||||
#dw.write_com(_binwrite.Command.DUMP_CODE)
|
#dw.write_com(_binwrite.Command.DUMP_CODE)
|
||||||
|
|
||||||
for net in ret_test:
|
for v in ret_test:
|
||||||
assert isinstance(net, backend.Net)
|
assert isinstance(v, value)
|
||||||
add_read_command(dw, variables, net)
|
add_read_command(dw, variables, v.net)
|
||||||
|
|
||||||
#dw.write_com(_binwrite.Command.READ_DATA)
|
#dw.write_com(_binwrite.Command.READ_DATA)
|
||||||
#dw.write_int(0)
|
#dw.write_int(0)
|
||||||
|
|
@ -148,7 +148,7 @@ def test_compile():
|
||||||
|
|
||||||
for test, ref in zip(ret_test, ret_ref):
|
for test, ref in zip(ret_test, ret_ref):
|
||||||
assert isinstance(test, value)
|
assert isinstance(test, value)
|
||||||
address = variables[test][0]
|
address = variables[test.net][0]
|
||||||
data = result_data[address]
|
data = result_data[address]
|
||||||
if test.dtype == 'int':
|
if test.dtype == 'int':
|
||||||
val = int.from_bytes(data, sdb.byteorder, signed=True)
|
val = int.from_bytes(data, sdb.byteorder, signed=True)
|
||||||
|
|
|
||||||
|
|
@ -120,9 +120,9 @@ def test_compile():
|
||||||
dw.write_com(_binwrite.Command.RUN_PROG)
|
dw.write_com(_binwrite.Command.RUN_PROG)
|
||||||
#dw.write_com(_binwrite.Command.DUMP_CODE)
|
#dw.write_com(_binwrite.Command.DUMP_CODE)
|
||||||
|
|
||||||
for net in ret_test:
|
for v in ret_test:
|
||||||
assert isinstance(net, backend.Net)
|
assert isinstance(v, value)
|
||||||
add_read_command(dw, variables, net)
|
add_read_command(dw, variables, v.net)
|
||||||
|
|
||||||
#dw.write_com(_binwrite.Command.READ_DATA)
|
#dw.write_com(_binwrite.Command.READ_DATA)
|
||||||
#dw.write_int(0)
|
#dw.write_int(0)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue