Source code for axs.catalog


from axs.axsframe import AxsFrame
from axs import Constants
from pyspark.sql import functions as F


class AxsCatalog:
    """
    Implements high-level operations on AXS tables, such as loading, saving, renaming and dropping.
    AXS tables are Spark Parquet tables, but are bucketed and sorted in a special way.

    `AxsCatalog` relies on the Spark Catalog for loading table data and it needs an active `SparkSession`
    object (with Hive support enabled) for initialization. This is also necessary because information
    about available AXS tables is persisted in a special AXS table in the Spark metastore DB.
    """
    ZONE_HEIGHT = Constants.ONE_AMIN
    NGBR_BORDER_HEIGHT = 10 * Constants.ONE_ASEC
    RA_COLNAME = "ra"
    DEC_COLNAME = "dec"
    DUP_COLNAME = "dup"
    ZONE_COLNAME = "zone"
    NUM_BUCKETS = Constants.NUM_BUCKETS

    __instance = None

    def __init__(self, spark_session=None):
        """
        Initializes the catalog using a connection to the Spark metastore of the current `SparkSession`.

        :param spark_session: An active and initialized Spark session which can be used for accessing the
            Spark catalog. Can be None if AxsCatalog has already been initialized at least once.
        """
        if spark_session is None and AxsCatalog.__instance is None:
            raise AttributeError("Spark session is None but AxsCatalog hasn't been initialized yet.")
        if spark_session is None:
            self.spark = AxsCatalog.__instance.spark
        else:
            self.spark = spark_session
        # _CatalogUtils is a bridge to the AXS Java helper functions
        self._CatalogUtils = self.spark.sparkContext._jvm.org.dirac.axs.util.CatalogUtils()
        self._CatalogUtils.setup(self.spark._jsparkSession)
        AxsCatalog.__instance = self
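
    # A minimal initialization sketch (illustrative, not part of the module itself).
    # It assumes a Hive-enabled SparkSession, as the class docstring requires; the
    # application name is hypothetical.
    #
    #   from pyspark.sql import SparkSession
    #   from axs.catalog import AxsCatalog
    #
    #   spark = (SparkSession.builder
    #            .appName("axs-demo")
    #            .enableHiveSupport()
    #            .getOrCreate())
    #   catalog = AxsCatalog(spark)
    #   catalog_again = AxsCatalog()  # later calls may omit the session (cached instance)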

    def load(self, table_name):
        """
        Loads a known AXS table from the Spark catalog and returns it as an `AxsFrame`.

        :param table_name: The name of the AXS table to load.
        """
        info = self.table_info(table_name)
        table = self.spark.read.table(table_name)
        return AxsFrame(table, info)
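
    # Usage sketch (illustrative): loading a registered AXS table as an AxsFrame.
    # The table name "gaia" is hypothetical and assumes the table was previously
    # saved with save_axs_table() or registered with import_existing_table().
    #
    #   gaia = catalog.load("gaia")
    #   gaia.select("ra", "dec").show(5)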

    def import_existing_table(self, table_name, path, num_buckets=500, zone_height=Constants.ONE_AMIN,
                              import_into_spark=True, update_spark_bucketing=True,
                              bucket_col='zone', ra_col='ra', dec_col='dec',
                              lightcurves=False, lc_cols=None):
        """
        Imports an existing, properly bucketed and sorted Parquet file into the AXS catalog.
        If `import_into_spark` is True, the table will also be imported into the Spark catalog.

        `bucket_col`, `ra_col` and `dec_col` allow for changing the default bucketing and sorting columns.

        If `lightcurves` is True, then some columns are expected to be array columns. Those are specified
        in `lc_cols`.

        :param table_name: The table name into which to import the data.
        :param path: The path to the bucketed Parquet file to be imported (needed only for Spark import).
        :param num_buckets: Number of buckets in the input Parquet file.
        :param zone_height: Zone height used for data partitioning.
        :param import_into_spark: Whether to also import the table into Spark. If `False`, the table
            should already exist in the Spark metastore.
        :param update_spark_bucketing: Whether to also update the bucketing info in the Spark metastore.
        :param bucket_col: The column used for data bucketing.
        :param ra_col: The name of the column containing RA coordinates.
        :param dec_col: The name of the column containing DEC coordinates.
        :param lightcurves: Whether the table contains lightcurve data as array columns.
        :param lc_cols: Comma-separated list of names of array columns containing lightcurve data.
        """
        if self._CatalogUtils.tableExists(table_name):
            raise AttributeError("Table %s already exists in AXS catalog" % table_name)
        if import_into_spark:
            self.spark.catalog.createTable(table_name, path, "parquet")
            if update_spark_bucketing:
                self._CatalogUtils.updateSparkMetastoreBucketing(table_name, num_buckets)
        self._CatalogUtils.saveNewTable(table_name, num_buckets, zone_height, bucket_col, ra_col, dec_col,
                                        lightcurves, lc_cols)
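
    # Usage sketch (illustrative): registering a Parquet file that is already bucketed
    # and sorted in the AXS layout. The table name, path, bucket count and lightcurve
    # column list are all hypothetical.
    #
    #   catalog.import_existing_table(
    #       "ztf", "/data/ztf_bucketed.parquet",
    #       num_buckets=500, zone_height=Constants.ONE_AMIN,
    #       lightcurves=True, lc_cols="mjd,mag,magerr")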

    def table_info(self, table_name):
        """
        Returns a known AxsFrame table info as a dictionary with these keys:

        - `table_id` - Internal table ID
        - `table_name` - Name of the table
        - `num_buckets` - Number of buckets used for partitioning the table data
        - `zone_height` - Zone height used for data partitioning
        - `bucket_col` - the column name used for bucketing
        - `ra_col` - the column containing RA coordinates
        - `dec_col` - the column containing DEC coordinates
        - `has_lightcurves` - whether the table contains lightcurve data as array columns
        - `lc_columns` - a list of array columns containing lightcurve data

        :param table_name: Table name for which to fetch info.
        """
        tbls = self._CatalogUtils.listTables()
        for x in tbls:
            if x.getTableName() == table_name:
                return {'table_id': x.getTableId(), 'table_name': x.getTableName(),
                        'num_buckets': x.getNumBuckets(), 'zone_height': x.getZoneHeight(),
                        'bucket_col': x.getBucketCol(), 'ra_col': x.getRaCol(), 'dec_col': x.getDecCol(),
                        'has_lightcurves': x.isLightcurves(), 'lc_columns': x.getLcColumns()}
        raise AttributeError("Table %s not found in AXS catalog" % table_name)

    def list_tables(self):
        """
        Returns all known AxsFrame tables as a dictionary where the keys are table names and the values
        are again dictionaries with these fields:

        - `table_id` - Internal table ID
        - `table_name` - Name of the table
        - `num_buckets` - Number of buckets used for partitioning the table data
        - `zone_height` - Zone height used for data partitioning
        - `bucket_col` - the column name used for bucketing
        - `ra_col` - the column containing RA coordinates
        - `dec_col` - the column containing DEC coordinates
        - `has_lightcurves` - whether the table contains lightcurve data as array columns
        - `lc_columns` - a list of array columns containing lightcurve data
        """
        tbls = self._CatalogUtils.listTables()
        res = {}
        for x in tbls:
            res[x.getTableName()] = {'table_id': x.getTableId(), 'table_name': x.getTableName(),
                                     'num_buckets': x.getNumBuckets(), 'zone_height': x.getZoneHeight(),
                                     'bucket_col': x.getBucketCol(), 'ra_col': x.getRaCol(),
                                     'dec_col': x.getDecCol(), 'has_lightcurves': x.isLightcurves(),
                                     'lc_columns': x.getLcColumns()}
        return res
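
    # Usage sketch (illustrative): inspecting the catalog contents. The keys of the
    # returned dictionaries match the fields documented above; the "gaia" table name
    # is hypothetical.
    #
    #   for name, info in catalog.list_tables().items():
    #       print(name, info['num_buckets'], info['zone_height'])
    #   gaia_info = catalog.table_info("gaia")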

    def save_axs_table(self, df, tblname, repartition=True, calculate_zone=False,
                       num_buckets=Constants.NUM_BUCKETS, zone_height=ZONE_HEIGHT):
        """
        Saves a Spark DataFrame as an AxsFrame under the name `tblname`. Also saves the table as a Spark
        table in the Spark catalog. The table will be bucketed into `num_buckets` buckets
        (`AxsCatalog.NUM_BUCKETS` by default), each bucket sorted by the `zone` and `ra` columns.

        Note: if saving intermediate results from cross-matching two AxsFrames, the DataFrame should
        already be partitioned appropriately. `repartition` should then be set to `False` to speed
        things up.

        To obtain the saved table, use the `load()` method.

        :param df: Spark DataFrame or AxsFrame to save as an AXS table.
        :param tblname: Table name to use for saving.
        :param repartition: Whether to repartition the data by zone before saving.
        :param calculate_zone: Whether to first add `zone` and `dup` columns to `df`.
        :param num_buckets: Number of buckets to use for data partitioning.
        :param zone_height: Zone height to use for data partitioning.
        """
        if self._CatalogUtils.tableExists(tblname):
            raise Exception("Table already exists: " + tblname)
        if AxsCatalog.DEC_COLNAME not in df.columns or AxsCatalog.RA_COLNAME not in df.columns:
            raise Exception("Cannot save as AXS table: '" + AxsCatalog.DEC_COLNAME +
                            "' or '" + AxsCatalog.RA_COLNAME + "' columns not found.")
        if AxsCatalog.ZONE_COLNAME not in df.columns and not calculate_zone:
            raise Exception("Cannot save as AXS table: '" + AxsCatalog.ZONE_COLNAME +
                            "' column not found and calculate_zone is not set.")
        # temporarily set spark.sql.shuffle.partitions to AxsCatalog.NUM_BUCKETS for the
        # zone repartition; the previous value is restored at the end
        old_partitions_conf = df.sql_ctx.getConf("spark.sql.shuffle.partitions")
        df.sql_ctx.setConf("spark.sql.shuffle.partitions", AxsCatalog.NUM_BUCKETS)
        newdf = df
        if calculate_zone:
            newdf = self.calculate_zone(newdf, zone_height)
        if repartition:
            newdf = newdf.repartition("zone")
        newdf.write.format("parquet"). \
            bucketBy(num_buckets, "zone").sortBy("zone", "ra"). \
            saveAsTable(tblname)
        self._CatalogUtils.saveNewTable(tblname, num_buckets, zone_height, AxsCatalog.ZONE_COLNAME,
                                        AxsCatalog.RA_COLNAME, AxsCatalog.DEC_COLNAME, False, None)
        df.sql_ctx.setConf("spark.sql.shuffle.partitions", old_partitions_conf)
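
    # Usage sketch (illustrative): saving a plain DataFrame that has "ra" and "dec"
    # columns but no "zone"/"dup" columns yet, so calculate_zone is enabled. The input
    # path and table name are hypothetical.
    #
    #   df = spark.read.parquet("/data/my_catalog_raw.parquet")
    #   catalog.save_axs_table(df, "my_catalog", repartition=True, calculate_zone=True)
    #   my_catalog = catalog.load("my_catalog")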

    def calculate_zone(self, df, zone_height=ZONE_HEIGHT):
        """
        Adds `zone` and `dup` columns to the `df` DataFrame. `df` needs to have a `dec` column for
        calculating zones and must not already have `zone` and `dup` columns.

        Data in the lower border strip of each zone is duplicated into the zone below it. The `dup` column
        contains 1 for those duplicated rows and 0 otherwise.

        :param df: The input DataFrame for which to calculate the `zone` and `dup` columns.
        :param zone_height: Zone height to be used for data partitioning.
        :return: The new DataFrame with `zone` and `dup` columns added.
        """
        if AxsCatalog.ZONE_COLNAME in df.columns:
            raise Exception("Cannot save as AXS table: '" + AxsCatalog.ZONE_COLNAME +
                            "' column already exists.")
        if AxsCatalog.DUP_COLNAME in df.columns:
            raise Exception("Cannot save as AXS table: '" + AxsCatalog.DUP_COLNAME +
                            "' column already exists.")
        # rows within NGBR_BORDER_HEIGHT of the bottom of their zone are duplicated
        # into the zone below (dup = 1); every original row keeps dup = 0
        return df.where(((df.dec + 90) > zone_height) &
                        ((df.dec + 90) % zone_height < AxsCatalog.NGBR_BORDER_HEIGHT)). \
            withColumn("zone", ((df.dec + 90) / zone_height - 1).cast("long")). \
            withColumn("dup", F.lit(1)). \
            union(df.withColumn("zone", ((df.dec + 90) / zone_height).cast("long")).
                  withColumn("dup", F.lit(0)))
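
    # Worked example (illustrative): with the default zone_height of one arcminute
    # (1/60 of a degree), a source at dec = +10.0 falls into zone
    # int((10.0 + 90) / (1/60)) = 6000; if it also lies within NGBR_BORDER_HEIGHT of
    # the bottom edge of that zone, a duplicate row with zone = 5999 and dup = 1 is
    # emitted as well.
    #
    #   zoned = catalog.calculate_zone(df)           # uses the default ZONE_HEIGHT
    #   zoned.groupBy("dup").count().show()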

    def add_increment(self, table_name, increment_df, rename_to=None, temp_tbl_name=None):
        """
        Adds a new batch of data contained in the `increment_df` DataFrame (or AxsFrame) to the persisted
        AXS table `table_name`. The old table will be renamed to `rename_to`, if set, or to
        "`table_name` + _YYYYMMDDhhmm" otherwise. The data will first be saved into `temp_tbl_name` before
        the main table is renamed.

        `increment_df` needs to have the same schema as the table `table_name`.

        :param table_name: The table to which to add the new data.
        :param increment_df: The data to add to the existing table. Needs to have the appropriate schema.
        :param rename_to: New table name for the original data.
        :param temp_tbl_name: Temporary table name to use (`table_name` + "_temp" is the default) before
            the rename operation.
        :return: The table name to which the original table has been renamed.
        """
        if temp_tbl_name is None:
            temp_tbl_name = table_name + "_temp"
        if rename_to is None:
            import datetime
            ts = datetime.datetime.now().strftime("%Y%m%d%H%M")
            rename_to = table_name + "_" + ts
        # drop duplicated border rows from the old data, add zones to the increment,
        # save the union under a temporary name, then swap the table names
        old = self.load(table_name)
        old = old.where(old.dup == 0)
        self.save_axs_table(old.union(self.calculate_zone(increment_df)), temp_tbl_name)
        self.rename_table(table_name, rename_to)
        self.rename_table(temp_tbl_name, table_name)
        self.drop_table(temp_tbl_name)
        return rename_to
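
    # Usage sketch (illustrative): appending a nightly batch to an existing AXS table.
    # Names are hypothetical; new_batch_df carries the table's ra/dec columns but no
    # zone/dup columns yet, since add_increment runs calculate_zone() on it.
    #
    #   backup_name = catalog.add_increment("my_catalog", new_batch_df)
    #   print("previous version kept as", backup_name)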

    def rename_table(self, table_name, new_name):
        """
        Renames the AXS table `table_name` to `new_name`. Also renames the table in the Spark catalog.

        :param table_name: An existing table to rename.
        :param new_name: The new table name.
        """
        if self._CatalogUtils.tableExists(new_name):
            raise AttributeError("Table " + new_name + " already exists!")
        if not self._CatalogUtils.tableExists(table_name):
            raise AttributeError("Table " + table_name + " does not exist!")
        self._CatalogUtils.renameTable(table_name, new_name)
        self.spark.sql("alter table " + table_name + " rename to " + new_name)

    def drop_table(self, table_name, drop_spark=True):
        """
        Drops a table from both the AXS and Spark catalogs.

        :param table_name: Table to drop.
        :param drop_spark: Whether to drop the table in the Spark catalog, too. Default is True.
        """
        try:
            self._CatalogUtils.deleteTable(table_name)
        except Exception as e:
            print(e)
        if drop_spark:
            try:
                self.spark.sql("drop table " + table_name)
            except Exception as e:
                print(e)
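
    # Usage sketch (illustrative): housekeeping calls. All table names are hypothetical.
    #
    #   catalog.rename_table("my_catalog", "my_catalog_v1")
    #   catalog.drop_table("scratch_table")                      # AXS and Spark catalogs
    #   catalog.drop_table("axs_only_table", drop_spark=False)   # AXS catalog only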