# Source code for dcase_util.containers.mixins
# !/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function, absolute_import
import os
import sys
import logging
import csv
import zipfile
import tarfile
from dcase_util.utils import setup_logging, FileFormat, Path, get_file_hash, is_jupyter
from dcase_util.ui import FancyLogger, FancyStringifier, FancyHTMLStringifier
class ContainerMixin(object):
    """Container mixin to give class basic container methods."""

    def __init__(self, *args, **kwargs):
        """Constructor

        Only attributes which do not already exist on the instance are initialized.

        Parameters
        ----------
        log_progress : bool
            Stored in ``self.log_progress``.
            Default value True

        disable_progress_bar : bool
            Stored in ``self.disable_progress_bar``.
            Default value False

        use_ascii_progress_bar : bool
            Stored in ``self.use_ascii_progress_bar``.
            Default value True

        """

        if not hasattr(self, 'ui'):
            self.ui = FancyLogger()

        # Setup progress bar
        if not hasattr(self, 'log_progress'):
            self.log_progress = kwargs.get('log_progress', True)

        if not hasattr(self, 'disable_progress_bar'):
            self.disable_progress_bar = kwargs.get('disable_progress_bar', False)

        if not hasattr(self, 'use_ascii_progress_bar'):
            self.use_ascii_progress_bar = kwargs.get('use_ascii_progress_bar', True)

    def __getstate__(self):
        # The mixin itself contributes nothing to the pickled state.
        return {}

    def __setstate__(self, d):
        # The given state is ignored; only the ui logger is re-created.
        self.ui = FancyLogger()

    def __str__(self):
        return self.to_string()

    @property
    def logger(self):
        """Logger instance

        Module level logger; ``setup_logging()`` is called first if the logger has no handlers yet.

        """

        logger = logging.getLogger(__name__)
        if not logger.handlers:
            setup_logging()

        return logger

    def show(self, mode='auto', indent=0, visualize=False):
        """Print container content

        If called inside Jupyter notebook HTML formatted version is shown.

        Parameters
        ----------
        mode : str
            Output type, possible values ['auto', 'print', 'html']. 'html' will work only in Jupyter notebook
            Default value 'auto'

        indent : int
            Amount of indent
            Default value 0

        visualize : bool
            Visualize container data if class has plot method. Used only in 'html' mode.
            Default value False

        Raises
        ------
        ValueError:
            Unknown mode

        Returns
        -------
        Nothing

        """

        if mode == 'auto':
            mode = 'html' if is_jupyter() else 'print'

        if mode not in ['html', 'print']:
            # Unknown mode given
            message = '{name}: Unknown mode [{mode}]'.format(name=self.__class__.__name__, mode=mode)
            # No exception is being handled here, so log a plain error
            # (logger.exception would append an empty traceback).
            self.logger.error(message)
            raise ValueError(message)

        if mode == 'html':
            # IPython.display is the public location of these names;
            # importing them from IPython.core.display is deprecated.
            from IPython.display import display, HTML

            display(
                HTML(
                    self.to_html(indent=indent)
                )
            )

            if visualize and hasattr(self, 'plot'):
                # If class has plot method use it to visualize the content
                self.plot()

        elif mode == 'print':
            print(self.to_string(indent=indent))

    def to_string(self, ui=None, indent=0):
        """Get container information in a string

        Parameters
        ----------
        ui : FancyStringifier or FancyHTMLStringifier
            Stringifier class
            Default value FancyStringifier

        indent : int
            Amount of indent
            Default value 0

        Returns
        -------
        str

        """

        if ui is None:
            ui = FancyStringifier()

        output = ui.class_name(self.__class__.__name__, indent=indent) + '\n'

        # Filename line is added only when the attribute exists and is non-empty.
        if hasattr(self, 'filename') and self.filename:
            output += ui.data(field='filename', value=self.filename, indent=indent) + '\n'

        return output

    def to_html(self, indent=0):
        """Get container information in a HTML formatted string

        Parameters
        ----------
        indent : int
            Amount of indent
            Default value 0

        Returns
        -------
        str

        """

        return self.to_string(ui=FancyHTMLStringifier(), indent=indent)

    def log(self, level='info'):
        """Log container content

        Parameters
        ----------
        level : str
            Logging level, possible values [info, debug, warn, warning, error, critical]

        Returns
        -------
        Nothing

        """

        self.ui.line(self.__str__(), level=level)
class FileMixin(object):
    """File mixin to give class methods to load and store content."""

    def __init__(self, *args, **kwargs):
        """Constructor

        Only attributes which do not already exist on the instance are initialized, with the exception of
        ``valid_formats`` which is overwritten whenever a non-empty value is given.
        If a filename is available, the file format is detected and validated.

        Parameters
        ----------
        filename : str
            File path
            Default value None

        valid_formats : list
            Accepted file formats
            Default value []

        log_progress : bool
            Default value True

        disable_progress_bar : bool
            Default value False

        use_ascii_progress_bar : bool
            Default value True

        """

        if not hasattr(self, 'ui'):
            self.ui = FancyLogger()

        valid_formats = kwargs.get('valid_formats', None)
        if valid_formats:
            self.valid_formats = valid_formats

        elif not hasattr(self, 'valid_formats'):
            self.valid_formats = []

        if not hasattr(self, 'filename'):
            # Empty filenames are normalized to None
            self.filename = kwargs.get('filename', None) or None

        if not hasattr(self, 'format'):
            self.format = None

        if hasattr(self, 'filename') and self.filename is not None:
            self.detect_file_format()
            self.validate_format()

        # Setup progress bar
        if not hasattr(self, 'log_progress'):
            self.log_progress = kwargs.get('log_progress', True)

        if not hasattr(self, 'disable_progress_bar'):
            self.disable_progress_bar = kwargs.get('disable_progress_bar', False)

        if not hasattr(self, 'use_ascii_progress_bar'):
            self.use_ascii_progress_bar = kwargs.get('use_ascii_progress_bar', True)

    def __getstate__(self):
        # The mixin itself contributes nothing to the pickled state.
        return {}

    def __setstate__(self, d):
        # The given state is ignored; only the ui logger is re-created.
        self.ui = FancyLogger()

    @property
    def logger(self):
        """Logger instance

        Module level logger; ``setup_logging()`` is called first if the logger has no handlers yet.

        """

        logger = logging.getLogger(__name__)
        if not logger.handlers:
            setup_logging()

        return logger

    @property
    def md5(self):
        """Checksum for file.

        Returns
        -------
        str or None
            None if the file does not exist

        """

        if self.exists():
            return get_file_hash(filename=self.filename)

        return None

    @property
    def bytes(self):
        """File size in bytes

        Returns
        -------
        int or None
            None if the file does not exist

        """

        if self.exists():
            return os.path.getsize(self.filename)

        return None

    def get_file_information(self):
        """Get file information, filename

        Returns
        -------
        str
            Empty string if no filename is set

        """

        if self.filename:
            return 'Filename: [' + self.filename + ']'

        return ''

    def detect_file_format(self, filename=None):
        """Detect file format from extension

        The detected format is stored in ``self.format``.

        Parameters
        ----------
        filename : str
            filename, if None given ``self.filename`` is used.
            Default value None

        Raises
        ------
        IOError:
            Unknown file format

        Returns
        -------
        self

        """

        if filename is None:
            filename = self.filename

        file_format = FileFormat.detect(filename=filename)

        if file_format is None:
            # Unknown format
            message = '{name}: File format cannot be detected for file [{filename}] '.format(
                name=self.__class__.__name__,
                filename=filename
            )
            # Not inside an exception handler, so log a plain error.
            self.logger.error(message)
            raise IOError(message)

        self.format = file_format

        return self

    def validate_format(self):
        """Validate file format

        Validation is done only if ``self.valid_formats`` is non-empty.

        Raises
        ------
        IOError:
            Unknown file format

        Returns
        -------
        bool

        """

        if not self.valid_formats:
            # Nothing to validate against
            return True

        if not getattr(self, 'format', None):
            # Detect format if not defined yet
            self.detect_file_format()

        if self.format and self.format in self.valid_formats:
            # Format found and valid formats defined
            return True

        message = '{name}: Unknown format [{format}] for file [{file}]'.format(
            name=self.__class__.__name__,
            format=os.path.splitext(self.filename)[-1],
            file=self.filename
        )
        self.logger.error(message)
        raise IOError(message)

    def exists(self):
        """Checks that file exists

        Returns
        -------
        bool
            False also when no filename has been set

        """

        filename = getattr(self, 'filename', None)

        # os.path.isfile raises TypeError for None, so guard against missing filename
        return bool(filename) and os.path.isfile(filename)

    def empty(self):
        """Check if file is empty

        Relies on the host class implementing ``__len__``.

        Returns
        -------
        bool

        """

        return len(self) == 0

    def delimiter(self, exclude_delimiters=None):
        """Use csv.sniffer to guess delimiter for CSV file

        The first 1024 characters of the file are used for the guess. Tab is returned if sniffing
        fails or the sniffed delimiter is not among the accepted ones (tab, comma, semicolon, space).

        Parameters
        ----------
        exclude_delimiters : list of str
            List of delimiter to be excluded
            Default value None

        Returns
        -------
        str

        """

        excluded = set(exclude_delimiters or [])
        valid_delimiters = [item for item in ['\t', ',', ';', ' '] if item not in excluded]

        # Default, used whenever nothing acceptable is sniffed
        delimiter = '\t'

        with open(self.filename, 'rt') as f1:
            try:
                dialect = csv.Sniffer().sniff(f1.read(1024))

            except csv.Error:
                # Fall back to default
                return delimiter

        # Prefer the private attribute if the dialect has it, as the original lookup order did
        if hasattr(dialect, '_delimiter'):
            sniffed = dialect._delimiter

        else:
            sniffed = getattr(dialect, 'delimiter', None)

        if sniffed in valid_delimiters:
            delimiter = sniffed

        return delimiter

    def is_package(self, filename=None):
        """Determine if the file is compressed package.

        When a filename is available, the format is (re)detected and stored in ``self.format``
        before the check.

        Parameters
        ----------
        filename : str
            filename, if None given ``self.filename`` is used.
            Default value None

        Raises
        ------
        IOError:
            Unknown file format

        Returns
        -------
        bool

        """

        if filename is None and hasattr(self, 'filename'):
            filename = self.filename

        if filename:
            self.detect_file_format(filename)

        # Always answer with a bool, also when no format has been set
        return getattr(self, 'format', None) in (FileFormat.ZIP, FileFormat.TAR)
class PackageMixin(object):
    """Package mixin to give class basic methods to handle compressed file packages.

    The mixin uses ``self.filename``, ``self.format``, ``self.exists()``, ``self.detect_file_format()`` and
    ``self.validate_format()``, which are not defined here and must be provided by the host class.
    The package password is stored through item access (``self['package_password']``).

    """

    def __init__(self, *args, **kwargs):
        """Constructor

        Only attributes which do not already exist on the instance are initialized.

        Parameters
        ----------
        log_progress : bool
            Default value True

        disable_progress_bar : bool
            Default value False

        use_ascii_progress_bar : bool
            Default value True

        """

        if not hasattr(self, 'ui'):
            self.ui = FancyLogger()

        # Setup progress bar
        if not hasattr(self, 'log_progress'):
            self.log_progress = kwargs.get('log_progress', True)

        if not hasattr(self, 'disable_progress_bar'):
            self.disable_progress_bar = kwargs.get('disable_progress_bar', False)

        if not hasattr(self, 'use_ascii_progress_bar'):
            self.use_ascii_progress_bar = kwargs.get('use_ascii_progress_bar', True)

    @property
    def logger(self):
        """Logger instance

        Module level logger; ``setup_logging()`` is called first if the logger has no handlers yet.

        """

        logger = logging.getLogger(__name__)
        if not logger.handlers:
            setup_logging()

        return logger

    @property
    def package_password(self):
        """Package password

        Returns
        -------
        str or None

        """

        if 'package_password' in self:
            return self['package_password']

        return None

    @package_password.setter
    def package_password(self, value):
        self['package_password'] = value

    @property
    def md5(self):
        """Checksum for file.

        Returns
        -------
        str or None
            None if the file does not exist

        """

        if self.exists():
            return get_file_hash(filename=self.filename)

        return None

    @property
    def bytes(self):
        """File size in bytes

        Returns
        -------
        int or None
            None if the file does not exist

        """

        if self.exists():
            return os.path.getsize(self.filename)

        return None

    def extract(self, target_path=None, overwrite=False, omit_first_level=False):
        """Extract the package. Supports Zip and Tar packages.

        Parameters
        ----------
        target_path : str
            Path to extract the package content. If none given, package is extracted in the same path than package.
            Default value None

        overwrite : bool
            Overwrite existing files.
            Default value False

        omit_first_level : bool
            Omit first directory level. Applied only to Zip packages.
            Default value False

        Returns
        -------
        self

        """

        tqdm = self._tqdm()

        if target_path is None:
            target_path = os.path.split(self.filename)[0]

        Path(target_path).create()

        if self.format == FileFormat.ZIP:
            self._extract_zip(
                tqdm=tqdm,
                target_path=target_path,
                overwrite=overwrite,
                omit_first_level=omit_first_level
            )

        elif self.format == FileFormat.TAR:
            self._extract_tar(
                tqdm=tqdm,
                target_path=target_path,
                overwrite=overwrite
            )

        return self

    def compress(self, filename=None, path=None, file_list=None, size_limit=None):
        """Compress the package. Supports Zip and Tar packages.

        If the total size of the source files reaches ``size_limit``, the content is split into multiple
        packages named ``<base>.<package_id><extension>`` with package_id starting from 1.

        Parameters
        ----------
        filename : str
            Filename for the package. If None given, one given to class initializer is used.
            Default value None

        path : str
            Path get files if file_list is not set. Files are collected recursively.
            Default value None

        file_list : list of dict
            List of files to be included to the package.
            Item format {'source': 'file1.txt', 'target': 'folder1/file1.txt'}.
            Default value None

        size_limit : int
            Size limit in bytes.
            Default value None

        Raises
        ------
        ValueError:
            Neither path nor file_list given
        IOError:
            Non-existing source file or unsupported package format

        Returns
        -------
        list of str
            Filenames of created packages

        """

        tqdm = self._tqdm()

        if filename is not None:
            self.filename = filename
            self.detect_file_format()
            self.validate_format()

        if path is not None and file_list is None:
            file_list = [
                {
                    'source': source,
                    'target': os.path.relpath(source)
                }
                for source in Path(path=path).file_list(recursive=True)
            ]

        if file_list is None:
            message = '{name}: Either path or file_list has to be given to compress a package [{package}]'.format(
                name=self.__class__.__name__,
                package=self.filename
            )
            self.logger.error(message)
            raise ValueError(message)

        package_filenames = []

        total_uncompressed_size = sum(os.path.getsize(item['source']) for item in file_list)

        if size_limit is None or total_uncompressed_size < size_limit:
            # Everything fits into a single package
            package = self._open_package(self.filename)
            package_filenames.append(self.filename)

            try:
                for item in file_list:
                    self._check_source(item=item, package_filename=self.filename)
                    self._write_member(package=package, item=item)

            finally:
                # Close the package also when a member could not be added
                package.close()

            return package_filenames

        # Split the content into multiple packages
        base, extension = os.path.splitext(self.filename)
        filename_template = base + '.{package_id}' + extension

        package_id = 1
        current_filename = filename_template.format(package_id=package_id)
        package = self._open_package(current_filename)
        package_filenames.append(current_filename)

        size_uncompressed = 0
        member_count = 0

        progress = tqdm(
            file_list,
            desc="{0: <25s}".format('Compress'),
            file=sys.stdout,
            leave=False,
            disable=self.disable_progress_bar,
            ascii=self.use_ascii_progress_bar
        )

        try:
            for item_id, item in enumerate(progress):
                if self.disable_progress_bar:
                    self.logger.info('  {title:<15s} [{item_id:d}/{total:d}] {file:<30s}'.format(
                        title='Compress ',
                        item_id=item_id,
                        total=len(progress),
                        file=item['source'])
                    )

                self._check_source(item=item, package_filename=current_filename)

                current_size_uncompressed = os.path.getsize(item['source'])

                # Roll over only if the current package already has content; otherwise a single
                # file larger than the limit would leave an empty package behind.
                if member_count and size_uncompressed + current_size_uncompressed > size_limit:
                    # Size limit met, close current package and open a new one.
                    package.close()

                    package_id += 1
                    current_filename = filename_template.format(package_id=package_id)
                    package = self._open_package(current_filename)
                    package_filenames.append(current_filename)

                    size_uncompressed = 0
                    member_count = 0

                self._write_member(package=package, item=item)

                size_uncompressed += current_size_uncompressed
                member_count += 1

        finally:
            package.close()

        return package_filenames

    def _tqdm(self):
        """Get the tqdm progress bar callable matching the current environment."""

        if is_jupyter():
            from tqdm import tqdm_notebook as tqdm

        else:
            from tqdm import tqdm

        return tqdm

    def _password_bytes(self):
        """Get package password as bytes (as required by zipfile), or None if no password is set."""

        try:
            password = self.package_password

        except (AttributeError, TypeError):
            # Host class does not support item access, treat as no password
            password = None

        if not password:
            return None

        if not isinstance(password, bytes):
            password = password.encode('utf-8')

        return password

    def _extract_zip(self, tqdm, target_path, overwrite, omit_first_level):
        """Extract Zip package stored in self.filename into target_path."""

        # NOTE(review): member names are trusted as-is; confirm packages come from a trusted source.
        password = self._password_bytes()

        with zipfile.ZipFile(self.filename, "r") as z:
            offset = 0
            if omit_first_level:
                # Directory components of all file entries
                parts = [name.split('/')[:-1] for name in z.namelist() if not name.endswith('/')]

                prefix = os.path.commonprefix(parts) or ''
                if prefix:
                    # Only the first common directory level (and its trailing slash) is stripped
                    offset = len(prefix[0]) + 1

            # Start extraction
            progress = tqdm(
                z.infolist(),
                desc="{0: <25s}".format('Extract'),
                file=sys.stdout,
                leave=False,
                disable=self.disable_progress_bar,
                ascii=self.use_ascii_progress_bar
            )

            for i, member in enumerate(progress):
                if self.disable_progress_bar:
                    self.logger.info('  {title:<15s} [{item_id:d}/{total:d}] {file:<30s}'.format(
                        title='Extract ',
                        item_id=i,
                        total=len(progress),
                        file=member.filename)
                    )

                if omit_first_level:
                    if len(member.filename) <= offset:
                        # Entry is the omitted directory level itself
                        continue

                    member.filename = member.filename[offset:]

                progress.set_description("{0: >35s}".format(member.filename.split('/')[-1]))
                progress.update()

                destination = os.path.join(target_path, member.filename)
                if os.path.isfile(destination) and not overwrite:
                    continue

                try:
                    z.extract(
                        member=member,
                        path=target_path,
                        pwd=password
                    )

                except KeyboardInterrupt:
                    # Delete latest file, since most likely it was not extracted fully
                    if os.path.isfile(destination):
                        os.remove(destination)

                    # Quit
                    sys.exit()

    def _extract_tar(self, tqdm, target_path, overwrite):
        """Extract Tar package stored in self.filename into target_path."""

        # NOTE(review): member names are trusted as-is; confirm packages come from a trusted source.
        # 'r:*' opens the archive with transparent compression, gzip included.
        with tarfile.open(self.filename, "r:*") as tar:
            progress = tqdm(
                tar,
                desc="{0: <25s}".format('Extract'),
                file=sys.stdout,
                leave=False,
                disable=self.disable_progress_bar,
                ascii=self.use_ascii_progress_bar
            )

            for i, tar_info in enumerate(progress):
                if self.disable_progress_bar:
                    # Archive is read as a stream, so the total member count is not known here
                    self.logger.info('  {title:<15s} [{item_id:d}] {file:<30s}'.format(
                        title='Extract ',
                        item_id=i,
                        file=tar_info.name)
                    )

                if not os.path.isfile(os.path.join(target_path, tar_info.name)) or overwrite:
                    tar.extract(tar_info, target_path)

                # Drop the member list collected so far
                tar.members = []

    def _open_package(self, filename):
        """Open a new package for writing in the format given by self.format."""

        if self.format == FileFormat.ZIP:
            return zipfile.ZipFile(
                file=filename,
                mode='w'
            )

        if self.format == FileFormat.TAR:
            return tarfile.open(
                name=filename,
                mode='w:gz'
            )

        message = '{name}: Unknown format [{format}] for package [{package}]'.format(
            name=self.__class__.__name__,
            format=self.format,
            package=filename
        )
        self.logger.error(message)
        raise IOError(message)

    def _check_source(self, item, package_filename):
        """Raise IOError if the source file of the item does not exist."""

        if not os.path.exists(item['source']):
            message = '{name}: Non-existing file [{filename}] detected while compressing a package [{package}]'.format(
                name=self.__class__.__name__,
                filename=item['source'],
                package=package_filename
            )
            self.logger.error(message)
            raise IOError(message)

    def _write_member(self, package, item):
        """Add the source file of the item into an open package under its target name."""

        arcname = os.path.relpath(item['target'])

        if self.format == FileFormat.ZIP:
            package.write(
                filename=item['source'],
                arcname=arcname,
                compress_type=zipfile.ZIP_DEFLATED
            )

        elif self.format == FileFormat.TAR:
            package.add(
                name=item['source'],
                arcname=arcname
            )