Source code for dcase_util.containers.containers

#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function, absolute_import
from six import iteritems
from builtins import str as text

import os
import sys
import numpy
import copy
import csv
import json
import hashlib

from dcase_util.containers import ContainerMixin, FileMixin
from dcase_util.ui import FancyStringifier, FancyHTMLStringifier
from dcase_util.utils import is_float, is_int, FileFormat


class ObjectContainer(ContainerMixin, FileMixin):
    """Container class for object inherited from standard object class.

    The instance attributes (``__dict__``) are what ``save`` writes to disk and
    what ``load`` restores.
    """

    valid_formats = [FileFormat.CPICKLE]  #: Valid file formats
def __init__(self, *args, **kwargs):
    """Constructor

    All positional and keyword arguments are forwarded unchanged to both
    mixin initialisers and then to the regular ``super()`` chain.

    """

    # Run ContainerMixin init, then FileMixin init
    for mixin in (ContainerMixin, FileMixin):
        mixin.__init__(self, *args, **kwargs)

    super(ObjectContainer, self).__init__(*args, **kwargs)
def load(self, filename=None):
    """Load file

    The stored attribute dictionary is read from a cPickle file and merged
    into the attributes of this instance.

    Parameters
    ----------
    filename : str, optional
        File path. When given, it replaces the stored filename and the file
        format is detected and validated again.
        Default value filename given to class constructor

    Raises
    ------
    ImportError:
        Error if file format specific module cannot be imported
    IOError:
        File does not exists or has unknown file format

    Returns
    -------
    self

    """

    if filename:
        self.filename = filename
        self.detect_file_format()
        self.validate_format()

    owner = self.__class__.__name__

    # Guard: nothing to read
    if not self.exists():
        message = '{name}: File does not exists [{file}]'.format(name=owner, file=self.filename)
        self.logger.exception(message)
        raise IOError(message)

    from dcase_util.files import Serializer

    # Guard: only cPickle is supported
    if self.format != FileFormat.CPICKLE:
        message = '{name}: Unknown format [{format}]'.format(name=owner, format=self.filename)
        self.logger.exception(message)
        raise IOError(message)

    # NOTE(review): unpickling can execute arbitrary code - only load trusted files.
    stored_attributes = Serializer.load_cpickle(filename=self.filename)
    self.__dict__.update(stored_attributes)

    # Check if after load function is defined, call if found
    if hasattr(self, '_after_load'):
        self._after_load()

    return self
def save(self, filename=None):
    """Save file

    The attribute dictionary of the instance is written to a cPickle file.
    If a ``_before_save`` hook is defined, its return value is written
    instead; an ``_after_save`` hook is called once writing has finished.

    Parameters
    ----------
    filename : str, optional
        File path. When given, it replaces the stored filename and the file
        format is detected and validated again.
        Default value filename given to class constructor

    Raises
    ------
    ImportError:
        Error if file format specific module cannot be imported
    IOError:
        Filename is empty or file has unknown file format

    Returns
    -------
    self

    """

    if filename:
        self.filename = filename
        self.detect_file_format()
        self.validate_format()

    if self.filename is None or self.filename == '':
        message = '{name}: Filename is empty [{filename}]'.format(
            name=self.__class__.__name__,
            filename=self.filename
        )

        self.logger.exception(message)
        raise IOError(message)

    try:
        from dcase_util.files import Serializer

        data = self.__dict__

        # Check if before save function is defined, call if found
        if hasattr(self, '_before_save'):
            data = self._before_save(data)

        if self.format == FileFormat.CPICKLE:
            Serializer.save_cpickle(filename=self.filename, data=data)

        else:
            message = '{name}: Unknown format [{format}]'.format(
                name=self.__class__.__name__,
                format=self.filename
            )
            self.logger.exception(message)
            raise IOError(message)

    except KeyboardInterrupt:
        # Delete the file, since most likely it was not saved fully. The file
        # may not exist yet when the interrupt arrives before writing starts;
        # removing it blindly would raise OSError and hide the interrupt.
        if os.path.isfile(self.filename):
            os.remove(self.filename)

        raise

    # Check if after save function is defined, call if found
    if hasattr(self, '_after_save'):
        self._after_save()

    return self
class DictContainer(dict, ContainerMixin, FileMixin):
    """Dictionary container class inherited from standard dict class.

    Adds dotted-path access to nested dictionaries, recursive merging, md5
    hashing of the content and file loading / saving on top of ``dict``.
    """

    # File formats accepted by format validation for this container
    valid_formats = [FileFormat.YAML, FileFormat.JSON, FileFormat.CPICKLE,
                     FileFormat.MARSHAL, FileFormat.MSGPACK,
                     FileFormat.TXT, FileFormat.CSV, FileFormat.META]  #: Valid file formats
def __init__(self, *args, **kwargs):
    """Constructor

    Positional arguments initialise the dictionary content. Keyword arguments
    are passed to the mixin initialisers only.

    Parameters
    ----------
    non_hashable_fields : list of str, optional
        Extra field names to leave out when the content is hashed. They are
        appended to the built-in ones ('_hash', 'verbose') without duplicates.

    """

    # Run ContainerMixin init, then FileMixin init
    for mixin in (ContainerMixin, FileMixin):
        mixin.__init__(self, *args, **kwargs)

    super(DictContainer, self).__init__(*args)

    self.non_hashable_fields = [
        '_hash',
        'verbose',
    ]

    extra_fields = kwargs.get('non_hashable_fields')
    if extra_fields:
        # dict.fromkeys drops duplicates while keeping the first-seen order
        combined = self.non_hashable_fields + extra_fields
        self.non_hashable_fields = list(dict.fromkeys(combined))
def __getstate__(self):
    """Return the state provided by the next class in the resolution order."""
    state = super(DictContainer, self).__getstate__()
    return state


def __setstate__(self, d):
    """Hand the state over to the next class in the resolution order."""
    super(DictContainer, self).__setstate__(d)


def to_string(self, ui=None, indent=0):
    """Get container information in a string

    Parameters
    ----------
    ui : FancyStringifier or FancyHTMLStringifier
        Stringifier class
        Default value FancyStringifier

    indent : int
        Amount of indent
        Default value 0

    Returns
    -------
    str

    """

    if ui is None:
        ui = FancyStringifier()

    # Header line with the class name
    parts = [ui.class_name(self.__class__.__name__, indent=indent) + '\n']

    # Filename line, only when a non-empty filename is set
    if getattr(self, 'filename', None):
        parts.append(ui.data(field='filename', value=self.filename, indent=indent) + '\n')

    # Nested content
    parts.append(self._walk_and_show(self, depth=0, ui=ui, indent=indent))

    return ''.join(parts)
def get_path(self, path, default=None, data=None):
    """Get value from nested dict with dotted path

    Parameters
    ----------
    path : str or list of str
        String in form of "field1.field2.field3", or a list with one entry
        per level. List entries are used as keys as they are, so they may
        contain dots. The field "*" matches every key of the level.

    default : str, int, float
        Default value returned if path does not exists, or if the path runs
        through a value which is not a dict.
        Default value None

    data : dict, optional
        Dict for which path search is done, if None given self is used.
        Default value None

    Raises
    ------
    ValueError:
        Path is neither a string nor a list

    Returns
    -------
    Value stored under the path. A dict value is returned wrapped into
    DictContainer, and a path containing "*" returns a list with one entry
    per matched key.

    """

    if data is None:
        data = self

    if isinstance(path, list):
        fields = list(path)

    elif isinstance(path, str):
        fields = path.split('.')

    else:
        message = '{name}: Unknown type of dotted_path [{path}].'.format(
            name=self.__class__.__name__,
            path=path
        )

        self.logger.exception(message)
        raise ValueError(message)

    def resolve(node, remaining):
        # Only dicts have children. Stopping here also keeps a None value in
        # the middle of the path from being mistaken for "search in self".
        if not remaining or not isinstance(node, dict):
            return default

        head, tail = remaining[0], remaining[1:]

        if head == '*':
            # Magic field to return all children in a list
            if tail:
                return [resolve(child, tail) for child in node.values()]

            return list(node.values())

        if head not in node:
            return default

        value = node[head]
        if tail:
            # Go deeper
            return resolve(value, tail)

        # We reached to the node, return DictContainer in case of dict
        if isinstance(value, dict):
            return DictContainer(value)

        return value

    return resolve(data, fields)
def set_path(self, path, new_value, data=None):
    """Set value in nested dict with dotted path

    Missing levels are created as plain dicts. A non-dict value standing in
    the way of an explicitly named level is overwritten with a dict.

    Parameters
    ----------
    path : str or list of str
        String in form of "field1.field2.field3", or a list with one entry
        per level. List entries are used as keys as they are, so they may
        contain dots. The field "*" addresses every existing key of the
        level; children which are not dicts are left untouched when the
        path continues below the wildcard.

    new_value :
        new value to be placed

    data : dict, optional
        Dict for which path search is done, if None given self is used.
        Default value None

    Raises
    ------
    ValueError:
        Path is neither a string nor a list

    Returns
    -------
    None

    """

    if data is None:
        data = self

    if isinstance(path, list):
        fields = list(path)

    elif isinstance(path, str):
        fields = path.split('.')

    else:
        message = '{name}: Unknown type of dotted_path [{path}].'.format(
            name=self.__class__.__name__,
            path=path
        )

        self.logger.exception(message)
        raise ValueError(message)

    def assign(node, remaining):
        head, tail = remaining[0], remaining[1:]

        if head == '*':
            # Magic field to set all children. Keys are copied first because
            # values are replaced while walking.
            for key in list(node.keys()):
                if not tail:
                    node[key] = new_value

                elif isinstance(node[key], dict):
                    # Descending into None would restart from the root and
                    # other non-dict values cannot hold a child, skip them.
                    assign(node[key], tail)

        elif not tail:
            # We reached to the node
            node[head] = new_value

        else:
            if not isinstance(node.get(head), dict):
                # Create missing level, or overwrite a non-dict value in the path
                node[head] = {}

            assign(node[head], tail)

    assign(data, fields)
def get_leaf_path_list(self, target_field=None, target_field_startswith=None, target_field_endswith=None):
    """Get path list to all leaf node in the nested dict.

    A path is kept when its last field satisfies at least one of the given
    filters. Without any filter every leaf path is returned.

    Parameters
    ----------
    target_field : str
        Field name to filter paths.
        Default value None

    target_field_startswith : str
        Start of field name to filter paths.
        Default value None

    target_field_endswith : str
        End of field name to filter paths.
        Default value None

    Returns
    -------
    list
        Dotted paths in sorted order

    """

    unfiltered = (
        target_field is None
        and target_field_startswith is None
        and target_field_endswith is None
    )

    selected = []
    for fields in list(self._path_generator()):
        leaf = fields[-1]

        keep = (
            (target_field is not None and leaf == target_field)
            or (target_field_startswith is not None and leaf.startswith(target_field_startswith))
            or (target_field_endswith is not None and leaf.endswith(target_field_endswith))
            or unfiltered
        )

        if keep:
            selected.append('.'.join(fields))

    return sorted(selected)
def merge(self, override, target=None):
    """ Recursive dict merge

    Values from override are written into target. When both sides hold a dict
    under the same key, the two dicts are merged level by level instead of
    the target dict being replaced.

    Parameters
    ----------
    override : dict
        override parameter dict

    target : dict
        target parameter dict, if None given self is used. An empty dict is
        a valid target and is filled in place.
        Default value None

    Returns
    -------
    self

    """

    # Only None selects self. A truthiness test would also catch empty nested
    # dicts and send their content to the root level.
    if target is None:
        target = self

    # Pairs of (destination, source) dicts still to be merged
    pending = [(target, override)]
    while pending:
        destination, source = pending.pop()

        for key, value in source.items():
            if key in destination and isinstance(destination[key], dict) and isinstance(value, dict):
                pending.append((destination[key], value))

            else:
                destination[key] = value

    return self
def get_hash_for_path(self, dotted_path=None):
    """ Get unique hash string for the data under given path.

    Parameters
    ----------
    dotted_path : str or list
        target path, if empty or None the whole container is hashed
        Default value None

    Returns
    -------
    str
        Unique hash for parameter dict, None if nothing is found under the path

    """

    # No path given, hash everything
    if not dotted_path:
        return self.get_hash(self)

    data = self.get_path(path=dotted_path)
    if data is None:
        return None

    return self.get_hash(data)
def get_hash(self, data=None):
    """Get unique hash string (md5) for given parameter dict.

    The data is deep copied, cleaned with ``_clean_for_hashing`` and
    serialised to JSON with sorted keys before hashing.

    Parameters
    ----------
    data : dict or list
        Input parameters, if None given the container content is used
        Default value None

    Returns
    -------
    str
        Unique hash for parameter dict

    """

    if data is None:
        data = dict(self)

    # Work on a copy, cleaning deletes keys in place
    cleaned = self._clean_for_hashing(copy.deepcopy(data))
    serialized = str(json.dumps(cleaned, sort_keys=True))

    digest = hashlib.md5()
    digest.update(serialized.encode('utf-8'))

    return digest.hexdigest()
def load(self, filename=None, file_format=None):
    """Load file

    The current content is cleared and replaced with the content of the
    file. TXT files become a mapping from line index to line, CSV files a
    mapping from first to second column (only rows with exactly two cells).

    Parameters
    ----------
    filename : str, optional
        File path, if None given, filename given to class constructor is used.
        Default value None

    file_format : FileFormat, optional
        Forced file format, use this when there is a miss-match between file extension and file format.
        Default value None

    Raises
    ------
    ImportError:
        Error if file format specific module cannot be imported
    IOError:
        File does not exists or has unknown file format

    Returns
    -------
    self

    """

    if filename:
        self.filename = filename
        if not file_format:
            self.detect_file_format()
            self.validate_format()

    if file_format and FileFormat.validate_label(label=file_format):
        self.format = file_format

    # Guard: nothing to read
    if not self.exists():
        message = '{name}: File does not exists [{file}]'.format(
            name=self.__class__.__name__,
            file=self.filename
        )
        self.logger.exception(message)
        raise IOError(message)

    # File exits
    from dcase_util.files import Serializer

    dict.clear(self)

    current_format = self.format
    if current_format in (FileFormat.YAML, FileFormat.META):
        loaded = Serializer.load_yaml(filename=self.filename)
        if loaded:
            dict.update(self, loaded)

    elif current_format == FileFormat.CPICKLE:
        # NOTE(review): unpickling can execute arbitrary code - only load trusted files.
        dict.update(self, Serializer.load_cpickle(filename=self.filename))

    elif current_format == FileFormat.MARSHAL:
        dict.update(self, Serializer.load_marshal(filename=self.filename))

    elif current_format == FileFormat.MSGPACK:
        dict.update(self, Serializer.load_msgpack(filename=self.filename))

    elif current_format == FileFormat.JSON:
        dict.update(self, Serializer.load_json(filename=self.filename))

    elif current_format == FileFormat.TXT:
        with open(self.filename, 'r') as f:
            lines = f.readlines()

        # Line index is the key, the line (line break included) is the value
        dict.update(self, dict(enumerate(lines)))

    elif current_format == FileFormat.CSV:
        delimiter = self.delimiter()
        with open(self.filename, 'r') as f:
            pairs = {
                row[0]: row[1]
                for row in csv.reader(f, delimiter=delimiter)
                if len(row) == 2
            }

        dict.update(self, pairs)

    else:
        message = '{name}: Unknown format [{format}]'.format(
            name=self.__class__.__name__,
            format=self.filename
        )
        self.logger.exception(message)
        raise IOError(message)

    # Call __setstate__ to handle internal variables of the class
    self.__setstate__(d=self)

    # Check if after load function is defined, call if found
    if hasattr(self, '_after_load'):
        self._after_load()

    return self
def save(self, filename=None, file_format=None):
    """Save file

    The container content, updated with the output of ``__getstate__``, is
    written to the file. If a ``_before_save`` hook is defined, its return
    value is written instead; an ``_after_save`` hook is called once writing
    has finished.

    Parameters
    ----------
    filename : str, optional
        File path, if None given, filename given to class constructor is used.
        Default value None

    file_format : FileFormat, optional
        Forced file format, use this when there is a miss-match between file extension and file format.
        Default value None

    Raises
    ------
    ImportError:
        Error if file format specific module cannot be imported
    IOError:
        Filename is empty or file has unknown file format

    Returns
    -------
    self

    """

    if filename:
        self.filename = filename
        if not file_format:
            self.detect_file_format()
            self.validate_format()

    if file_format and FileFormat.validate_label(label=file_format):
        self.format = file_format

    if self.filename is None or self.filename == '':
        message = '{name}: Filename is empty [{filename}]'.format(
            name=self.__class__.__name__,
            filename=self.filename
        )

        self.logger.exception(message)
        raise IOError(message)

    try:
        from dcase_util.files import Serializer

        data = dict(self)
        if hasattr(self, '__getstate__'):
            data.update(dict(self.__getstate__()))

        # Check if before save function is defined, call if found
        if hasattr(self, '_before_save'):
            data = self._before_save(data)

        if self.format == FileFormat.YAML or self.format == FileFormat.META:
            # META is read with the YAML loader in load(), so it is written
            # with the YAML writer to keep the two methods symmetric.
            Serializer.save_yaml(filename=self.filename, data=self.get_dump_content(data=data))

        elif self.format == FileFormat.CPICKLE:
            Serializer.save_cpickle(filename=self.filename, data=data)

        elif self.format == FileFormat.MARSHAL:
            Serializer.save_marshal(filename=self.filename, data=data)

        elif self.format == FileFormat.MSGPACK:
            Serializer.save_msgpack(filename=self.filename, data=data)

        elif self.format == FileFormat.JSON:
            Serializer.save_json(filename=self.filename, data=data)

        elif self.format == FileFormat.TXT:
            # Values are written as they are, no line breaks are added
            with open(self.filename, "w") as text_file:
                for line_id in self:
                    text_file.write(self[line_id])

        else:
            # NOTE(review): CSV is listed in valid_formats and handled by
            # load(), but there is no writer for it here - confirm whether
            # that is intended.
            message = '{name}: Unknown format [{format}]'.format(
                name=self.__class__.__name__,
                format=self.filename
            )
            self.logger.exception(message)
            raise IOError(message)

    except KeyboardInterrupt:
        # Delete the file, since most likely it was not saved fully. The file
        # may not exist yet when the interrupt arrives before writing starts;
        # removing it blindly would raise OSError and hide the interrupt.
        if os.path.isfile(self.filename):
            os.remove(self.filename)

        raise

    # Check if after save function is defined, call if found
    if hasattr(self, '_after_save'):
        self._after_save()

    return self
def get_dump_content(self, data): """Clean internal content for saving Numpy, DictContainer content is converted to standard types Parameters ---------- data : dict Returns ------- dict """ if data: data = dict(data) for k, v in iteritems(data): if isinstance(v, numpy.generic): data[k] = numpy.asscalar(v) elif isinstance(v, DictContainer): data[k] = self.get_dump_content(data=dict(data[k])) elif isinstance(v, dict): data[k] = self.get_dump_content(data=data[k]) return data def _path_generator(self, data=None, current_path_list=None): if data is None: data = self current_path_list = current_path_list[:] if current_path_list else [] if isinstance(data, dict): for key, value in data.items(): if isinstance(value, dict): # We have dict, let's go deeper for d in self._path_generator(value, current_path_list + [key]): yield d else: # We have reached leaf yield current_path_list + [key] def _walk_and_show(self, d, depth=0, ui=None, indent=0): """Recursive dict walk to get string of the content nicely formatted Parameters ---------- d : dict Dict for walking depth : int Depth of walk, string is indented with this Default value 0 ui : FancyStringifier or FancyHTMLStringifier Stringifier class Default value FancyStringifier indent : int Amount of indent Default value 0 Returns ------- str """ if ui is None: ui = FancyStringifier() output = u'' indent_increase = 2 for k, v in sorted(d.items(), key=lambda x: x[0]): k = text(k) if isinstance(v, dict): output += ui.data(field=k, value="dict (%d)" % len(v.keys()), indent=indent + depth * indent_increase) + '\n' output += self._walk_and_show(d=v, depth=depth + 1, ui=ui, indent=indent) else: if isinstance(v, numpy.ndarray): # Numpy array or matrix shape = v.shape if len(shape) == 1: output += ui.data( field=k, value="array (%d)" % (v.shape[0]), indent=indent + depth * indent_increase ) + '\n' elif len(shape) == 2: output += ui.data( field=k, value="matrix (%d, %d)" % (v.shape[0], v.shape[1]), indent=indent + depth * indent_increase 
) + '\n' elif len(shape) == 3: output += ui.data( field=k, value="matrix (%d, %d, %d)" % (v.shape[0], v.shape[1], v.shape[2]), indent=indent + depth * indent_increase ) + '\n' elif len(shape) == 4: output += ui.data( field=k, value="matrix (%d, %d, %d, %d)" % (v.shape[0], v.shape[1], v.shape[2], v.shape[3]), indent=indent + depth * indent_increase ) + '\n' elif isinstance(v, list) and len(v) and isinstance(v[0], str): output += ui.data( field=k, value="list (%d)" % len(v), indent=indent + depth * indent_increase ) + '\n' for item_id, item in enumerate(v): output += ui.data( field='[' + text(item_id) + ']', value=item, indent=indent + (depth+1) * indent_increase ) + '\n' elif isinstance(v, list) and len(v) and isinstance(v[0], numpy.ndarray): # List of arrays output += ui.data( field=k, value="list (%d)" % len(v), indent=indent + depth * indent_increase ) + '\n' for item_id, item in enumerate(v): if len(item.shape) == 1: output += ui.data( field='[' + text(item_id) + ']', value="array (%d)" % (item.shape[0]), indent=indent + (depth + 1) * indent_increase ) + '\n' elif len(item.shape) == 2: output += ui.data( field='[' + text(item_id) + ']', value="matrix (%d, %d)" % (item.shape[0], item.shape[1]), indent=indent + (depth + 1) * indent_increase ) + '\n' elif len(item.shape) == 3: output += ui.data( field='[' + text(item_id) + ']', value="matrix (%d, %d, %d)" % (item.shape[0], item.shape[1], item.shape[2]), indent=indent + (depth + 1) * indent_increase ) + '\n' elif len(item.shape) == 4: output += ui.data( field='[' + text(item_id) + ']', value="matrix (%d, %d, %d, %d)" % (item.shape[0], item.shape[1], item.shape[2], item.shape[3]), indent=indent + (depth + 1) * indent_increase ) + '\n' elif isinstance(v, list) and len(v) and isinstance(v[0], dict): output += ui.data( field=k, value="list (%d)" % len(v), indent=indent + depth * indent_increase ) + '\n' for item_id, item in enumerate(v): output += ui.data( field='[' + text(item_id) + ']', value='', indent=indent + 
(depth + 1) * indent_increase ) + '\n' output += self._walk_and_show(d=item, depth=depth + 2, ui=ui, indent=indent) else: output += ui.data(field=k, value=v, indent=indent + depth*indent_increase) + '\n' return output def _clean_for_hashing(self, data, non_hashable_fields=None): """Recursively remove keys with value set to False, or non hashable fields Parameters ---------- data : dict Data to be processed non_hashable_fields : list List of fields to be removed. Default value None Returns ------- str """ if non_hashable_fields is None and hasattr(self, 'non_hashable_fields'): non_hashable_fields = self.non_hashable_fields elif non_hashable_fields is None: non_hashable_fields = [] if data: if 'enable' in data and not data['enable']: return { 'enable': False, } else: if isinstance(data, dict): for key in list(data.keys()): value = data[key] if isinstance(value, bool) and value is False: # Remove fields marked False del data[key] elif key in non_hashable_fields: # Remove fields marked in non_hashable_fields list del data[key] elif isinstance(value, dict): if 'enable' in value and not value['enable']: # Remove dict block which is disabled del data[key] else: # Proceed recursively data[key] = self._clean_for_hashing(value) elif hasattr(value, 'get_config'): data[key] = value.get_config() return data else: return data else: return data def filter(self, data=None, excluded_key_prefix='_'): """Filter nested dict Parameters ---------- data : dict, optional Dict to filter, if None given self is used. Default value None excluded_key_prefix : str Key prefix to be excluded Default value '_' Returns ------- DictContainer """ if data is None: data = copy.deepcopy(self) for key in list(data.keys()): if isinstance(data[key], dict): data[key] = self.filter( data=data[key], excluded_key_prefix=excluded_key_prefix ) else: if excluded_key_prefix and key.startswith(excluded_key_prefix): del data[key] return data
class ListContainer(list, ContainerMixin, FileMixin):
    """List container class inherited from standard list class.

    Adds file loading / saving and string formatting on top of ``list``.
    """

    valid_formats = [FileFormat.TXT, FileFormat.CPICKLE]  #: Valid file formats
def __init__(self, *args, **kwargs):
    """Constructor

    Positional arguments initialise the list content. All arguments are
    also passed to the mixin initialisers.

    Parameters
    ----------
    filename : str, optional
        File path

    """

    # Run ContainerMixin init, then FileMixin init
    for mixin in (ContainerMixin, FileMixin):
        mixin.__init__(self, *args, **kwargs)

    # Run list init
    list.__init__(self, *args)
def to_string(self, ui=None, indent=0):
    """Get container information in a string

    Parameters
    ----------
    ui : FancyStringifier or FancyHTMLStringifier
        Stringifier class
        Default value FancyStringifier

    indent : int
        Amount of indent
        Default value 0

    Returns
    -------
    str

    """

    if ui is None:
        ui = FancyStringifier()

    # Header line with the class name
    parts = [ui.class_name(self.__class__.__name__, indent=indent) + '\n']

    # Filename line, only when a non-empty filename is set
    if getattr(self, 'filename', None):
        parts.append(ui.data(field='filename', value=self.filename, indent=indent) + '\n')

    # One line per item, labelled with the item index
    for item_id, item in enumerate(self):
        label = '[' + str(item_id) + ']'
        parts.append(ui.data(field=label, value=str(item), indent=indent) + '\n')

    return ''.join(parts)


def __getstate__(self):
    """Return the state provided by the next class in the resolution order."""
    state = super(ListContainer, self).__getstate__()
    return state


def __setstate__(self, d):
    """Hand the state over to the next class in the resolution order."""
    super(ListContainer, self).__setstate__(d)
def update(self, data):
    """Replace content with given list

    The list is re-initialised in place, so the previous items are discarded
    and the same object is returned.

    Parameters
    ----------
    data : list
        New content

    Returns
    -------
    self

    """

    # Re-running list.__init__ swaps the stored items without creating a new object
    list.__init__(self, data)
    return self
def load(self, filename=None, headers=None, file_format=None):
    """Load file

    The current content is replaced with the content of the file. TXT files
    are read line by line with the line breaks removed.

    Parameters
    ----------
    filename : str, optional
        File path, if None given, filename given to class constructor is used.
        Default value None

    headers : list of str, optional
        List of column names. Not used by this method.
        Default value None

    file_format : FileFormat, optional
        Forced file format, use this when there is a miss-match between file extension and file format.
        Default value None

    Raises
    ------
    IOError:
        File does not exists or has unknown file format

    Returns
    -------
    self

    """

    if filename:
        self.filename = filename
        if not file_format:
            self.detect_file_format()
            self.validate_format()

    if file_format and FileFormat.validate_label(label=file_format):
        self.format = file_format

    # Guard: nothing to read
    if not self.exists():
        message = '{name}: File does not exists [{file}]'.format(
            name=self.__class__.__name__,
            file=self.filename
        )
        self.logger.exception(message)
        raise IOError(message)

    from dcase_util.files import Serializer

    if self.format == FileFormat.TXT:
        with open(self.filename, 'r') as f:
            raw_lines = f.readlines()

        # Remove line breaks
        stripped = [line.replace('\r\n', '').replace('\n', '') for line in raw_lines]
        list.__init__(self, stripped)

    elif self.format == FileFormat.CPICKLE:
        # NOTE(review): unpickling can execute arbitrary code - only load trusted files.
        list.__init__(self, Serializer.load_cpickle(filename=self.filename))

    else:
        message = '{name}: Unknown format [{format}]'.format(
            name=self.__class__.__name__,
            format=self.filename
        )
        self.logger.exception(message)
        raise IOError(message)

    # Check if after load function is defined, call if found
    if hasattr(self, '_after_load'):
        self._after_load()

    return self
def save(self, filename=None, file_format=None):
    """Save file

    TXT files get one item per line, each converted with ``str``. An
    ``_after_save`` hook is called once writing has finished.

    Parameters
    ----------
    filename : str, optional
        File path, if None given, filename given to class constructor is used.
        Default value None

    file_format : FileFormat, optional
        Forced file format, use this when there is a miss-match between file extension and file format.
        Default value None

    Raises
    ------
    IOError:
        Filename is empty or file has unknown file format

    Returns
    -------
    self

    """

    if filename:
        self.filename = filename
        if not file_format:
            self.detect_file_format()
            self.validate_format()

    if file_format and FileFormat.validate_label(label=file_format):
        self.format = file_format

    if self.filename is None or self.filename == '':
        message = '{name}: Filename is empty [{filename}]'.format(
            name=self.__class__.__name__,
            filename=self.filename
        )

        self.logger.exception(message)
        raise IOError(message)

    try:
        from dcase_util.files import Serializer

        if self.format == FileFormat.TXT:
            with open(self.filename, "w") as text_file:
                for line in self:
                    text_file.write(str(line) + '\n')

        elif self.format == FileFormat.CPICKLE:
            Serializer.save_cpickle(filename=self.filename, data=self)

        else:
            message = '{name}: Unknown format [{format}]'.format(
                name=self.__class__.__name__,
                format=self.filename
            )
            self.logger.exception(message)
            raise IOError(message)

    except KeyboardInterrupt:
        # Delete the file, since most likely it was not saved fully. The file
        # may not exist yet when the interrupt arrives before writing starts;
        # removing it blindly would raise OSError and hide the interrupt.
        if os.path.isfile(self.filename):
            os.remove(self.filename)

        raise

    # Check if after save function is defined, call if found
    if hasattr(self, '_after_save'):
        self._after_save()

    return self
def get_dump_content(self, data):
    """Clean internal content for saving

    Numpy, DictContainer content is converted to standard types

    Parameters
    ----------
    data : dict

    Returns
    -------
    dict
        Shallow copy with numpy scalars turned into Python scalars and
        nested dicts processed the same way. Empty input is returned as is.

    """

    if data:
        data = dict(data)

        for k, v in list(data.items()):
            if isinstance(v, numpy.generic):
                # numpy.asscalar has been removed from NumPy, item() is its replacement
                data[k] = v.item()

            elif isinstance(v, dict):
                # Covers DictContainer as well, since it is a dict subclass
                data[k] = self.get_dump_content(data=dict(v))

    return data
class ListDictContainer(ListContainer):
    """List of dictionaries container class inherited from standard list class.

    Adds searching, field extraction and filtering over the dictionaries, and
    CSV / YAML / cPickle file loading and saving.
    """

    valid_formats = [FileFormat.CSV, FileFormat.YAML, FileFormat.CPICKLE]  #: Valid file formats
def __init__(self, *args, **kwargs):
    """Constructor

    Positional arguments initialise the list content, and every item is then
    wrapped into DictContainer.

    Parameters
    ----------
    filename : str, optional
        File path

    """

    # Run ContainerMixin init, then FileMixin init
    for mixin in (ContainerMixin, FileMixin):
        mixin.__init__(self, *args, **kwargs)

    # Run list init
    list.__init__(self, *args)

    # Convert list items to DictContainers
    for position in range(len(self)):
        self[position] = DictContainer(self[position])
def to_string(self, ui=None, indent=0):
    """Get container information in a string

    Parameters
    ----------
    ui : FancyStringifier or FancyHTMLStringifier
        Stringifier class
        Default value FancyStringifier

    indent : int
        Amount of indent, applied to the class name, filename and item lines
        Default value 0

    Returns
    -------
    str

    """

    if ui is None:
        ui = FancyStringifier()

    output = ''
    output += ui.class_name(self.__class__.__name__, indent=indent) + '\n'

    if hasattr(self, 'filename') and self.filename:
        # The filename line follows the requested indent like the other lines
        output += ui.data(field='filename', value=self.filename, indent=indent) + '\n'

    for item_id, item in enumerate(self):
        output += ui.line(field='[' + str(item_id) + ']', indent=indent) + '\n'

        # Items are shown one indent level deeper than their index line
        output += DictContainer(item).to_string(ui=ui, indent=indent + 1) + '\n'

    return output
def search(self, key, value):
    """Search in the list of dictionaries

    Parameters
    ----------
    key : str
        Dict key for the search

    value :
        Value for the key to match

    Returns
    -------
    dict or None
        First item having the value under the key, None if there is no match

    """

    matches = (element for element in self if element.get(key) == value)
    return next(matches, None)
def load(self, filename=None, fields=None, csv_header=True, file_format=None, delimiter=None,
         convert_numeric_fields=True):
    """Load file

    The current content is replaced with the content of the file.

    Parameters
    ----------
    filename : str, optional
        File path, if None given, filename given to class constructor is used.
        Default value None

    fields : list of str, optional
        List of column names. Used only for CSV formatted files; when given
        together with csv_header, the header line is read and ignored.
        Default value None

    csv_header : bool, optional
        Read field names from first line (header). Used only for CSV formatted files.
        Default value True

    file_format : FileFormat, optional
        Forced file format, use this when there is a miss-match between file extension and file format.
        Default value None

    delimiter : str, optional
        Forced data delimiter for csv format. If None given, automatic delimiter sniffer used. Use this when sniffer
        does not work.
        Default value None

    convert_numeric_fields : bool, optional
        Convert int and float fields to correct type.
        Default value True

    Raises
    ------
    IOError:
        File does not exists or has unknown file format
    ImportError:
        YAML file content is not a list
    ValueError:
        No fields or csv_header set for CSV formatted file.

    Returns
    -------
    self

    """

    if filename:
        self.filename = filename
        if not file_format:
            self.detect_file_format()
            self.validate_format()

    if file_format and FileFormat.validate_label(label=file_format):
        self.format = file_format

    if self.exists():
        from dcase_util.files import Serializer

        if self.format == FileFormat.CSV:
            # csv_header is a flag, so a disabled header is False rather than
            # None. Without field names from either source rows cannot be
            # turned into dicts.
            if fields is None and not csv_header:
                message = '{name}: Parameters fields or csv_header has to be set for CSV files.'.format(
                    name=self.__class__.__name__
                )
                self.logger.exception(message)
                raise ValueError(message)

            data = []

            if not delimiter:
                delimiter = self.delimiter()

            with open(self.filename, 'r') as f:
                csv_reader = csv.reader(f, delimiter=delimiter)

                if csv_header:
                    # An empty file has no header line, it loads as an empty list
                    csv_fields = next(csv_reader, [])
                    if fields is None:
                        fields = csv_fields

                for row in csv_reader:
                    # Skip empty lines
                    if row:
                        if convert_numeric_fields:
                            for cell_id, cell_data in enumerate(row):
                                if is_int(cell_data):
                                    row[cell_id] = int(cell_data)

                                elif is_float(cell_data):
                                    row[cell_id] = float(cell_data)

                        data.append(dict(zip(fields, row)))

            list.__init__(self, data)

        elif self.format == FileFormat.YAML:
            data = Serializer.load_yaml(filename=self.filename)
            if isinstance(data, list):
                list.__init__(self, data)

            else:
                message = '{name}: YAML data is not in list format.'.format(
                    name=self.__class__.__name__
                )
                self.logger.exception(message)
                raise ImportError(message)

        elif self.format == FileFormat.CPICKLE:
            # NOTE(review): unpickling can execute arbitrary code - only load trusted files.
            list.__init__(self, Serializer.load_cpickle(filename=self.filename))

        else:
            message = '{name}: Unknown format [{format}]'.format(
                name=self.__class__.__name__,
                format=self.filename
            )
            self.logger.exception(message)
            raise IOError(message)

    else:
        message = '{name}: File does not exists [{file}]'.format(
            name=self.__class__.__name__,
            file=self.filename
        )
        self.logger.exception(message)
        raise IOError(message)

    # Check if after load function is defined, call if found
    if hasattr(self, '_after_load'):
        self._after_load()

    return self
def save(self, filename=None, fields=None, csv_header=True, file_format=None, delimiter=','):
    """Save file

    An ``_after_save`` hook is called once writing has finished.

    Parameters
    ----------
    filename : str, optional
        File path, if None given, filename given to class constructor is used.
        Default value None

    fields : list of str
        Fields in correct order, if none given all field in alphabetical order will be outputted.
        Used only for CSV formatted files. Items without a field get an empty cell.
        Default value None

    csv_header : bool
        In case of CSV formatted file, first line will contain field names. Names are taken from fields parameter.
        Default value True

    file_format : FileFormat, optional
        Forced file format, use this when there is a miss-match between file extension and file format.
        Default value None

    delimiter : str
        Delimiter to be used when saving data
        Default value ','

    Raises
    ------
    IOError:
        Filename is empty or file has unknown file format

    Returns
    -------
    self

    """

    if filename:
        self.filename = filename
        if not file_format:
            self.detect_file_format()
            self.validate_format()

    if file_format and FileFormat.validate_label(label=file_format):
        self.format = file_format

    if self.filename is None or self.filename == '':
        message = '{name}: Filename is empty [{filename}]'.format(
            name=self.__class__.__name__,
            filename=self.filename
        )

        self.logger.exception(message)
        raise IOError(message)

    try:
        from dcase_util.files import Serializer

        if self.format == FileFormat.YAML:
            # Work on a copy, so that the stored items stay as they are
            data = copy.deepcopy(list(self))
            for item_id, item in enumerate(data):
                data[item_id] = self.get_dump_content(data=item)

            Serializer.save_yaml(filename=self.filename, data=data)

        elif self.format == FileFormat.CSV:
            if fields is None:
                # Union of the keys over all items, in alphabetical order
                fields = set()
                for item in self:
                    fields.update(list(item.keys()))

                fields = sorted(list(fields))

            # Make sure writing is using correct line endings to avoid extra empty lines
            if sys.version_info[0] == 2:
                csv_file = open(self.filename, 'wb')

            elif sys.version_info[0] >= 3:
                csv_file = open(self.filename, 'w', newline='')

            try:
                csv_writer = csv.writer(csv_file, delimiter=delimiter)
                if csv_header:
                    csv_writer.writerow(fields)

                for item in self:
                    # Items do not need to share the same keys. A missing
                    # field gives None, which the csv writer outputs as an
                    # empty cell.
                    csv_writer.writerow([item.get(field) for field in fields])

            finally:
                csv_file.close()

        elif self.format == FileFormat.CPICKLE:
            Serializer.save_cpickle(filename=self.filename, data=self)

        else:
            message = '{name}: Unknown format [{format}]'.format(
                name=self.__class__.__name__,
                format=self.filename
            )
            self.logger.exception(message)
            raise IOError(message)

    except KeyboardInterrupt:
        # Delete the file, since most likely it was not saved fully. The file
        # may not exist yet when the interrupt arrives before writing starts;
        # removing it blindly would raise OSError and hide the interrupt.
        if os.path.isfile(self.filename):
            os.remove(self.filename)

        raise

    # Check if after save function is defined, call if found
    if hasattr(self, '_after_save'):
        self._after_save()

    return self
def get_field(self, field_name, skip_items_without_field=True):
    """Get all data from field.

    Parameters
    ----------
    field_name : str
        Dict key for the search

    skip_items_without_field : bool
        Skip items without field. If False, None is inserted to the output
        for every item which does not have the field.
        Default value True

    Returns
    -------
    list

    """

    if skip_items_without_field:
        return [item[field_name] for item in self if field_name in item]

    # Keep one output entry per item
    return [item[field_name] if field_name in item else None for item in self]
def get_field_unique(self, field_name):
    """Get unique data from field.

    Parameters
    ----------
    field_name : str
        Dict key for the search

    Returns
    -------
    list
        Unique values in sorted order, items without the field are ignored

    """

    return sorted(set(self.get_field(field_name=field_name)))


def remove_field(self, field_name):
    """Remove field from data items

    Parameters
    ----------
    field_name : str
        Field name

    Returns
    -------
    self

    """

    for item in self:
        if field_name in item:
            del item[field_name]

    return self


def filter(self, case_insensitive_fields=True, **kwargs):
    """Filter content based on field values.

    An item is selected when it fulfils every given condition. An item
    which does not have a filtered field at all is not selected.

    Parameters
    ----------
    case_insensitive_fields : bool
        Use case insensitive fields for filtering
        Default value True

    kwargs
        Use filtered field name as parameter name, and target value for the field as parameter value.
        If there is no field with the parameter name, underscores in the parameter name are replaced with
        whitespace when matching with field names. Value can be a single value or list of values.

    Returns
    -------
    ListDictContainer
        New container with deep copies of the selected items

    """

    if case_insensitive_fields:
        conditions = {field.lower(): target for field, target in kwargs.items()}

    else:
        conditions = dict(kwargs)

    def item_matches(item):
        # Map from the comparable field name to the field name used in the item
        field_map = {}
        for field in list(item.keys()):
            field_map[field.lower() if case_insensitive_fields else field] = field

        for condition_field, target in conditions.items():
            condition_field_alternative = condition_field.replace('_', ' ')

            if condition_field in field_map:
                item_value = item[field_map[condition_field]]

            elif condition_field_alternative in field_map:
                item_value = item[field_map[condition_field_alternative]]

            else:
                # A condition on a field which the item does not have cannot be fulfilled
                return False

            if isinstance(target, list):
                if item_value not in target:
                    return False

            elif not (item_value == target):
                return False

        return True

    data = [copy.deepcopy(item) for item in self if item_matches(item)]

    return ListDictContainer(data)
class RepositoryContainer(DictContainer):
    """Container class for repository, inherited from DictContainer.

    Behaves like DictContainer, but only the cPickle file format is accepted.
    """

    valid_formats = [FileFormat.CPICKLE]  #: Valid file formats
class TextContainer(ListContainer):
    """Container class for text, inherited from ListContainer.

    Behaves like ListContainer, but only the plain text file format is accepted.
    """

    valid_formats = [FileFormat.TXT]  #: Valid file formats