Training on Zhuang lab data

To train on the Zhuang lab data, see the config files labeled "Zhuang".

The setup is largely the same as for the AIBS case. One easy solution, and the one we use, is to load each metadata/anndata pair independently as a celltransformer.data.CenterMaskSampler and then concatenate the resulting datasets post-hoc with torch.utils.data.ConcatDataset.
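
Concretely, the loader below reads a handful of keys from config.data. Here is a minimal sketch of that structure using OmegaConf (the key names are taken from the snippet below; all values are placeholders, so consult the actual "Zhuang" config files for the real ones):

from omegaconf import OmegaConf

# Placeholder values only; the key names match what load_data reads below.
config = OmegaConf.create({
    "data": {
        "metadata_path": ["/path/to/zhuang_1_metadata.csv", "/path/to/zhuang_2_metadata.csv"],
        "adata_path": ["/path/to/zhuang_1.h5ad", "/path/to/zhuang_2.h5ad"],
        "celltype_colname": "subclass",  # raw cell-type label column (a guess)
        "cell_id_colname": "cell_label",
        "tissue_section_colname": "brain_section_label",
        "train_pct": 0.9,
        "patch_size": 100,  # placeholder
        "neighborhood_max_num_cells": 50,  # placeholder
        "batch_size": 64,  # placeholder
        "num_workers": 8,  # placeholder
    }
})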

For reference, here is a snippet of code from the ZhuangTrainer.load_data function (see scripts/training/train_zhuang.py), with some annotation by me:

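The snippet assumes the following imports at the top of the file (the celltransformer import path for CenterMaskSampler and collate is my guess; adjust it to wherever they actually live):

import warnings

import anndata as ad
import pandas as pd
import torch
from omegaconf import DictConfig
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# Assumed location of the sampler and collate function within the package.
from celltransformer.data import CenterMaskSampler, collate
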
def load_data(self, config: DictConfig):

    all_dfs = []
    all_cls = set()

    for df_path in config.data.metadata_path: 

        # We loop over all the dataframes first, so we can generate a consistent
        # list of all the cell types in the datasets.

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", pd.errors.DtypeWarning)
            metadata = pd.read_csv(df_path)

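        # Rescale the spatial coordinates (presumably to match the units the sampler expects).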
        metadata["x"] = metadata["x"] * 100
        metadata["y"] = metadata["y"] * 100

        metadata = metadata.reset_index(drop=True)

        all_cls.update(metadata[config.data.celltype_colname].unique())
        all_dfs.append(metadata)

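    # Fit a single encoder over the union of cell types so that the integer
    # labels are consistent across all datasets.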
    le = LabelEncoder()
    le.fit(sorted(all_cls))


    # Now that we have a consistent encoder (`le`), we can use it as we loop again
    # over the metadata/anndata pairs, creating a train/valid `CenterMaskSampler`
    # pair for each one and appending them to the `trn_samplers` and
    # `valid_samplers` lists.

    trn_samplers = []
    valid_samplers = []

    for df, anndata_path in zip(all_dfs, config.data.adata_path):
        df["cell_type"] = le.transform(df[config.data.celltype_colname])
        df["cell_type"] = df["cell_type"].astype(int)

        df['cell_label'] = df['cell_label'].astype(str)

        df = df[['cell_type', 'cell_label', 'x', 'y', 'brain_section_label']]

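        # Load the expression data and align its rows to the metadata by cell label.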
        adata = ad.read_h5ad(anndata_path)
        adata = adata[df["cell_label"]]

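        # Split each dataset into train/validation indices independently.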
        train_indices, valid_indices = train_test_split(
            range(len(adata)), train_size=config.data.train_pct
        )

        train_sampler = CenterMaskSampler(
            metadata=df,
            adata=adata,
            patch_size=config.data.patch_size,
            cell_id_colname=config.data.cell_id_colname,
            cell_type_colname="cell_type",
            tissue_section_colname=config.data.tissue_section_colname,
            max_num_cells=config.data.neighborhood_max_num_cells,
            indices=train_indices,
        )

        valid_sampler = CenterMaskSampler(
            metadata=df,
            adata=adata,
            patch_size=config.data.patch_size,
            cell_id_colname=config.data.cell_id_colname,
            cell_type_colname="cell_type",
            tissue_section_colname=config.data.tissue_section_colname,
            max_num_cells=config.data.neighborhood_max_num_cells,
            indices=valid_indices,
        )

        trn_samplers.append(train_sampler)
        valid_samplers.append(valid_sampler)

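    # Concatenate the per-dataset samplers into a single training dataset and loader.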
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.ConcatDataset(trn_samplers),
        batch_size=config.data.batch_size,
        num_workers=config.data.num_workers,
        pin_memory=False,
        shuffle=True, 
        collate_fn=collate,
        prefetch_factor=4,  # tweaking this can improve throughput, depending on your setup
    )
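
The function builds the validation loader the same way from valid_samplers; a minimal sketch of that step, assuming the same collate function (validation does not need shuffling):

    valid_loader = torch.utils.data.DataLoader(
        torch.utils.data.ConcatDataset(valid_samplers),
        batch_size=config.data.batch_size,
        num_workers=config.data.num_workers,
        pin_memory=False,
        shuffle=False,  # deterministic order is fine for validation
        collate_fn=collate,
    )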