Source code for models.node.processing.encoder.singletoonehot
import abc
from typing import Final, Dict
from models.exception.non_compatible_data import NonCompatibleData
from models.framework_data import FrameworkData
from models.node.processing.processing_node import ProcessingNode
from typing import List
class SingleToOneHot(ProcessingNode):
    """ Converts a single channel encoding (Ordinal encoding) signal to a one-hot encoding.

    One-hot encoded signals are signals where each label is represented by a vector of the same length as the number
    of labels, where the label is represented by a 1 at the index of the label and 0 everywhere else.

    A Single channel encoded (Ordinal encoding) signal is a type of signal where each label is represented by a number,
    where the label is represented by the index of the label.

    This node converts a single channel labels to a one-hot encoded label. The single label count starts at 1, so the
    label 1 is represented by the channel 1, the label 2 by the channel 2, etc. There is no label 0. An input value
    that does not correspond to any configured label therefore yields 0 on every output channel.

    Attributes:
        _MODULE_NAME (`str`): The name of the module (in this case ``node.processing.encoder.singletoonehot``)
        INPUT_MAIN (`str`): The name of the main input signal (in this case ``main``)
        OUTPUT_MAIN (`str`): The name of the main output signal (in this case ``main``)

    configuration.json usage:
        **module** (*str*): The name of the module (``node.processing.encoder``)\n
        **type** (*str*): The type of the node (``SingleToOneHot``)\n
        **labels**: The output channel names, one per label, in label order. Read from the ``labels`` key of the
        parameters passed to this node (presumably a list of strings -- confirm against the configuration).\n
        **buffer_options** (*dict*): Buffer options.\n
            **clear_output_buffer_on_data_input** (*bool*): Whether to clear the output buffer when new data is inserted in the input buffer.\n
            **clear_input_buffer_after_process** (*bool*): Whether to clear the input buffer after processing.\n
            **clear_output_buffer_after_process** (*bool*): Whether to clear the output buffer after processing.\n
    """

    _MODULE_NAME: Final[str] = 'node.processing.encoder.singletoonehot'
    INPUT_MAIN: Final[str] = 'main'
    OUTPUT_MAIN: Final[str] = 'main'

    # NOTE: this method is a complete implementation, so it must not be decorated with
    # ``abc.abstractmethod``; under an ABCMeta-based parent that decorator would make this
    # concrete node impossible to instantiate.
    def _initialize_parameter_fields(self, parameters: dict):
        """ Initializes the parameter fields of this node.

        Delegates to the parent implementation first and then stores the ``labels`` entry of the parameters,
        which defines the output channels of the one-hot encoding.

        :param parameters: The parameters passed to this node. Must contain the ``labels`` key.
        :type parameters: dict

        :raises KeyError: If ``labels`` is missing from ``parameters``.
        """
        super()._initialize_parameter_fields(parameters)
        self.labels = parameters['labels']

    def _is_next_node_call_enabled(self) -> bool:
        """ Returns whether the next node call is enabled. The next node call is enabled if the main output buffer
        is not empty.
        """
        return self._output_buffer[self.OUTPUT_MAIN].get_data_count() > 0

    def _is_processing_condition_satisfied(self) -> bool:
        """ Returns whether the processing condition is satisfied. The processing condition is satisfied if the
        main input buffer is not empty.
        """
        return self._input_buffer[self.INPUT_MAIN].get_data_count() > 0

    def _process(self, data: Dict[str, FrameworkData]) -> Dict[str, FrameworkData]:
        """ This method encodes the data labels that before was a single channel label (Ordinal encoding) to a
        one-hot encoded label. For each data point it compares the index of every output channel to the label and
        sets that output channel to 1 if the index is equal to the label minus one (labels start at 1) and to 0
        otherwise.

        :param data: The data to process. The entry under ``INPUT_MAIN`` is the signal to encode.
        :type data: dict

        :raises NonCompatibleData: If the input data is not single channel.

        :return: The processed data, keyed by ``OUTPUT_MAIN``.
        :rtype: dict
        """
        self.print('encoding...')
        raw_data = data[self.INPUT_MAIN]
        if not raw_data.is_1d():
            raise NonCompatibleData(module=self._MODULE_NAME, name=self.name, cause='provided_data_is_multichannel')

        encoded_data: FrameworkData = FrameworkData(
            sampling_frequency_hz=raw_data.sampling_frequency,
            channels=self.labels
        )

        for data_entry in raw_data.get_data_single_channel():
            # Labels are 1-based while channel positions are 0-based; the target position does not
            # depend on the channel, so compute it once per sample.
            active_index = data_entry - 1
            for channel_index, channel in enumerate(self.labels):
                encoded_value = 1 if channel_index == active_index else 0
                encoded_data.input_data_on_channel([encoded_value], channel)

        self.print('encoded!')
        return {
            self.OUTPUT_MAIN: encoded_data
        }

    def _get_inputs(self) -> List[str]:
        """ Returns the input fields of this node.
        """
        return [
            self.INPUT_MAIN
        ]

    def _get_outputs(self) -> List[str]:
        """ Returns the output fields of this node.
        """
        return [
            self.OUTPUT_MAIN
        ]