package io.semla.datasource;

import io.semla.config.ShardedDatasourceConfiguration;
import io.semla.model.EntityModel;
import io.semla.query.Pagination;
import io.semla.query.Predicates;
import io.semla.query.Values;
import io.semla.reflect.Member;
import io.semla.reflect.Types;
import io.semla.serialization.annotations.TypeInfo;
import io.semla.serialization.annotations.TypeName;
import io.semla.util.Maps;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:io/semla/datasource/ShardedDatasource.class */
public class ShardedDatasource<T> extends Datasource<T> {
    /** Strategy asked to pick, out of {@link #datasources}, the shard responsible for a key. */
    protected final ShardingStrategy shardingStrategy;
    /** The underlying shards that every operation of this datasource is dispatched to. */
    protected final List<Datasource<T>> datasources;
    /**
     * When {@code true}, reads that find an entity on a shard other than the one selected by
     * {@link #shardingStrategy} re-create it on the selected shard and delete it from the shard
     * it was found on. NOTE(review): the name is a misspelling of "rebalancing".
     */
    protected final boolean rebalacing;

    /**
     * Sharding strategy that derives the shard index from the key itself.
     *
     * <p>Integer keys are mapped to index {@code (key - 1) mod shardCount}; any other key is
     * mapped the same way using the code of the first character of its {@code toString()}.
     */
    @TypeName("keyed")
    /* loaded from: input_file:io/semla/datasource/ShardedDatasource$KeyedShardingStrategy.class */
    public static class KeyedShardingStrategy implements ShardingStrategy {
        /**
         * Selects the shard responsible for the given key.
         *
         * @param obj  the key to shard on, must not be {@code null}
         * @param list the candidate shards
         * @return the shard at the index derived from the key
         * @throws IllegalArgumentException if {@code obj} is {@code null}
         * @throws StringIndexOutOfBoundsException if a non-integer key has an empty string form
         */
        @Override // io.semla.datasource.ShardedDatasource.ShardingStrategy
        public <T> Datasource<T> selectFor(Object obj, List<Datasource<T>> list) {
            if (obj == null) {
                throw new IllegalArgumentException("Cannot shard on a generated key!");
            }
            int hash;
            if (Types.isAssignableTo(obj.getClass(), (Class<?>) Integer.class)) {
                hash = ((Integer) obj).intValue();
            } else {
                // Non-integer keys are sharded on the code of their first character.
                hash = obj.toString().charAt(0);
            }
            // floorMod keeps the index within [0, size) for keys <= 0, where the plain
            // remainder operator would produce a negative index and make list.get() throw.
            return list.get(Math.floorMod(hash - 1, list.size()));
        }
    }

    /** Decides which of the available shards is responsible for a given key. */
    @TypeInfo
    /* loaded from: input_file:io/semla/datasource/ShardedDatasource$ShardingStrategy.class */
    public interface ShardingStrategy {
        /**
         * Picks one shard for a key.
         *
         * @param obj  the key to select a shard for
         * @param list the candidate shards
         * @param <T>  the entity type stored by the shards
         * @return the shard selected for the key
         */
        <T> Datasource<T> selectFor(Object obj, List<Datasource<T>> list);
    }

    /**
     * Creates a datasource that spreads the entities of a model over several shards.
     *
     * @param entityModel      the model of the stored entities, handed to the parent datasource
     * @param shardingStrategy the strategy selecting the shard for a key
     * @param z                whether misplaced entities are moved to their selected shard on read
     * @param list             the shards to dispatch to; the list is stored as is, not copied
     */
    public ShardedDatasource(EntityModel<T> entityModel, ShardingStrategy shardingStrategy, boolean z, List<Datasource<T>> list) {
        super(entityModel);
        this.datasources = list;
        this.rebalacing = z;
        this.shardingStrategy = shardingStrategy;
    }

    /**
     * Exposes the shards backing this datasource.
     *
     * @return the internal list of shards (the very same instance, not a copy)
     */
    @Override // io.semla.datasource.Datasource
    public List<Datasource<T>> raw() {
        return datasources;
    }

    /**
     * Fetches the entity stored under the given key.
     *
     * <p>The shard selected for the key is queried first. When nothing is found there and
     * rebalancing is enabled, every shard is queried; an entity found that way is re-created
     * through {@link #create(Object)} and then deleted from the shard it was found on.
     *
     * @param obj the key of the entity
     * @return the entity, or an empty optional when no queried shard holds it
     */
    @Override // io.semla.datasource.Datasource
    public Optional<T> get(Object obj) {
        Optional<T> entity = forKey(obj).get(obj);
        if (!this.rebalacing || entity.isPresent()) {
            return entity;
        }
        // Remembers a shard that actually held the entity so it can be cleaned up afterwards.
        AtomicReference<Datasource<T>> source = new AtomicReference<>();
        Optional<T> misplaced = this.<Optional<T>>map(datasource -> {
            Optional<T> candidate = datasource.get(obj);
            if (candidate.isPresent()) {
                source.set(datasource);
            }
            return candidate;
        }).filter(Optional::isPresent).map(Optional::get).findFirst();
        if (misplaced.isPresent()) {
            // Create on the selected shard before deleting, so a failed create loses nothing.
            create(misplaced.get());
            source.get().delete(obj);
        }
        return misplaced;
    }

    /**
     * Fetches the entities stored under the given keys.
     *
     * <p>Keys are first looked up on the shard selected for each of them. When rebalancing is
     * enabled, keys whose value came back {@code null} are then looked up on every shard;
     * entities found that way are added to the result, re-created through
     * {@link #create(Collection)} and finally deleted from the shard they were found on.
     *
     * @param collection the keys to look up
     * @param <K>        the key type
     * @return the entities by key, as produced by the shards and merged by {@code Maps.collect()}
     */
    @Override // io.semla.datasource.Datasource
    public <K> Map<K, T> get(Collection<K> collection) {
        Map<K, T> entitiesByKey = this.<K, Map<K, T>>map(collection, (datasource, shardKeys) -> datasource.get(shardKeys))
            .flatMap(shardEntities -> shardEntities.entrySet().stream())
            .collect(Maps.collect());
        if (!this.rebalacing) {
            return entitiesByKey;
        }
        List<K> missingKeys = entitiesByKey.entrySet().stream()
            .filter(entry -> entry.getValue() == null)
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());
        if (missingKeys.isEmpty()) {
            return entitiesByKey;
        }

        // The shards are queried in parallel, so the lambda must not touch shared mutable
        // state: each shard only returns what it found and the merge happens sequentially below.
        List<Map.Entry<Datasource<T>, Map<K, T>>> recovered = this.<Map.Entry<Datasource<T>, Map<K, T>>>map(datasource -> {
            Map<K, T> found = datasource.get(missingKeys).entrySet().stream()
                .filter(entry -> entry.getValue() != null)
                .collect(Maps.collect());
            return new AbstractMap.SimpleImmutableEntry<>(datasource, found);
        }).collect(Collectors.toList());

        List<T> misplaced = new ArrayList<>();
        for (Map.Entry<Datasource<T>, Map<K, T>> shardResult : recovered) {
            entitiesByKey.putAll(shardResult.getValue());
            misplaced.addAll(shardResult.getValue().values());
        }
        if (!misplaced.isEmpty()) {
            // Re-create on the selected shards first; the stale copies are removed afterwards.
            create(misplaced);
        }
        for (Map.Entry<Datasource<T>, Map<K, T>> shardResult : recovered) {
            if (!shardResult.getValue().isEmpty()) {
                shardResult.getKey().delete(shardResult.getValue().keySet());
            }
        }
        return entitiesByKey;
    }

    /**
     * Creates the entity on the shard selected for its key.
     *
     * @param t the entity to create
     */
    @Override // io.semla.datasource.Datasource
    public void create(T t) {
        // The entity is passed as is: casting it to Datasource<T> matches no create overload.
        forEntity(t).create(t);
    }

    /**
     * Creates the entities, each batch on the shard selected for its members.
     *
     * @param collection the entities to create
     */
    @Override // io.semla.datasource.Datasource
    public void create(Collection<T> collection) {
        foreach(collection, (shard, shardEntities) -> shard.create(shardEntities));
    }

    /**
     * Updates the entity on the shard selected for its key.
     *
     * @param t the entity to update
     */
    @Override // io.semla.datasource.Datasource
    public void update(T t) {
        // The entity is passed as is: casting it to Datasource<T> matches no update overload.
        forEntity(t).update(t);
    }

    /**
     * Updates the entities, each batch on the shard selected for its members.
     *
     * @param collection the entities to update
     */
    @Override // io.semla.datasource.Datasource
    public void update(Collection<T> collection) {
        foreach(collection, (shard, shardEntities) -> shard.update(shardEntities));
    }

    /**
     * Deletes the entity stored under the given key from the shard selected for that key.
     *
     * @param obj the key of the entity to delete
     * @return whatever the selected shard reports for the deletion
     */
    @Override // io.semla.datasource.Datasource
    public boolean delete(Object obj) {
        Datasource<T> shard = forKey(obj);
        return shard.delete(obj);
    }

    /**
     * Deletes the entities stored under the given keys, each batch of keys on the shard
     * selected for them.
     *
     * @param collection the keys of the entities to delete
     * @return the sum of the counts reported by the shards, {@code 0} when no shard was involved
     */
    @Override // io.semla.datasource.Datasource
    public long delete(Collection<?> collection) {
        return map(collection, (shard, shardKeys) -> shard.delete(shardKeys))
            .mapToLong(Long::longValue)
            .sum();
    }

    /**
     * Returns the first entity matching the predicates.
     *
     * <p>Every shard is asked for its own first match using a per-shard copy of the pagination
     * (see {@link #pagination(Pagination)}); the candidates are then run through
     * {@code pagination.paginate} and the first one is returned. When rebalancing is enabled, a
     * candidate found on a shard other than the one selected for it is created on the selected
     * shard and deleted from the shard it was found on.
     *
     * @param predicates the predicates the entity must match
     * @param pagination the pagination to apply
     * @return the first matching entity, if any
     */
    @Override // io.semla.datasource.Datasource
    public Optional<T> first(Predicates<T> predicates, Pagination<T> pagination) {
        return pagination.paginate(this.<Optional<T>>map(datasource -> {
            Optional<T> candidate = datasource.first(predicates, pagination(pagination));
            if (this.rebalacing && candidate.isPresent()) {
                T entity = candidate.get();
                Datasource<T> selected = forEntity(entity);
                if (!selected.equals(datasource)) {
                    selected.create(entity);
                    // Bound to Object so that the single-key delete overload is the one invoked.
                    Object key = model().key().member().getOn(entity);
                    datasource.delete(key);
                }
            }
            return candidate;
        }).filter(Optional::isPresent).map(Optional::get)).findFirst();
    }

    /**
     * Lists the entities matching the predicates.
     *
     * <p>Every shard is queried with a per-shard copy of the pagination (see
     * {@link #pagination(Pagination)}) and the combined results are run through
     * {@code pagination.paginate}. When rebalancing is enabled, every fetched entity that sits
     * on a shard other than the one selected for it is created on the selected shard and then
     * deleted from the shard it was fetched from.
     *
     * @param predicates the predicates the entities must match
     * @param pagination the pagination to apply
     * @return the matching entities
     */
    @Override // io.semla.datasource.Datasource
    public List<T> list(Predicates<T> predicates, Pagination<T> pagination) {
        // Shards are queried in parallel; each task only returns its own result so that no
        // shared collection is mutated concurrently. Bookkeeping happens sequentially below.
        List<Map.Entry<Datasource<T>, List<T>>> resultsByShard = this.<Map.Entry<Datasource<T>, List<T>>>map(datasource ->
            new AbstractMap.SimpleImmutableEntry<>(datasource, datasource.list(predicates, pagination(pagination)))
        ).collect(Collectors.toList());

        List<T> entities = pagination.paginate(
            resultsByShard.stream().flatMap(shardResult -> shardResult.getValue().stream())
        ).collect(Collectors.toList());

        if (this.rebalacing) {
            // selected shard -> entities to create there
            Map<Datasource<T>, List<T>> arrivals = new LinkedHashMap<>();
            // shard the entity was fetched from -> keys to delete there
            Map<Datasource<T>, List<Object>> departures = new LinkedHashMap<>();
            for (Map.Entry<Datasource<T>, List<T>> shardResult : resultsByShard) {
                Datasource<T> source = shardResult.getKey();
                for (T entity : shardResult.getValue()) {
                    Datasource<T> selected = forEntity(entity);
                    if (selected.equals(source)) {
                        continue;
                    }
                    Object key = model().key().member().getOn(entity);
                    arrivals.computeIfAbsent(selected, shard -> new ArrayList<>()).add(entity);
                    departures.computeIfAbsent(source, shard -> new ArrayList<>()).add(key);
                }
            }
            // Create everything before deleting anything, so a failed create loses no data.
            arrivals.forEach((target, misplaced) -> target.create(misplaced));
            departures.forEach((origin, keys) -> origin.delete(keys));
        }
        return entities;
    }

    /**
     * Applies the values to the entities matching the predicates.
     *
     * <p>Without pagination the patch is delegated to every shard and the reported counts are
     * summed. With pagination the matching entities are listed through
     * {@link #list(Predicates, Pagination)}, the values are applied to each of them and the
     * results are written back through {@link #update(Collection)}.
     *
     * @param values     the values to apply
     * @param predicates the predicates the entities must match
     * @param pagination the pagination to apply
     * @return the number of patched entities
     */
    @Override // io.semla.datasource.Datasource
    public long patch(Values<T> values, Predicates<T> predicates, Pagination<T> pagination) {
        if (!pagination.isPaginated()) {
            return this.<Long>map(datasource -> datasource.patch(values, predicates, pagination))
                .reduce(0L, Long::sum);
        }
        List<T> patched = list(predicates, pagination).stream()
            .map(values::apply)
            .collect(Collectors.toList());
        update(patched);
        return patched.size();
    }

    /**
     * Deletes the entities matching the predicates.
     *
     * <p>Without pagination the deletion is delegated to every shard. With pagination the
     * matching entities are listed through {@link #list(Predicates, Pagination)} and their keys
     * are then deleted on every shard. In both cases the counts reported by the shards are
     * summed.
     *
     * @param predicates the predicates the entities must match
     * @param pagination the pagination to apply
     * @return the sum of the counts reported by the shards
     */
    @Override // io.semla.datasource.Datasource
    public long delete(Predicates<T> predicates, Pagination<T> pagination) {
        if (pagination.isPaginated()) {
            Stream<T> matches = list(predicates, pagination).stream();
            Member<T> keyMember = model().key().member();
            Collection<?> keys = matches.map(keyMember::getOn).collect(Collectors.toList());
            return map(shard -> shard.delete(keys))
                .mapToLong(Long::longValue)
                .sum();
        }
        return map(shard -> shard.delete(predicates, pagination))
            .mapToLong(Long::longValue)
            .sum();
    }

    /**
     * Counts the entities matching the predicates across all shards.
     *
     * @param predicates the predicates the entities must match
     * @return the sum of the counts reported by the shards, {@code 0} when there is no shard
     */
    @Override // io.semla.datasource.Datasource
    public long count(Predicates<T> predicates) {
        return map(shard -> shard.count(predicates))
            .mapToLong(Long::longValue)
            .sum();
    }

    /**
     * Applies the function to every shard through a parallel stream.
     *
     * @param function the function to apply to each shard
     * @param <E>      the result type
     * @return a stream of the per-shard results
     */
    protected <E> Stream<E> map(Function<Datasource<T>, E> function) {
        Stream<Datasource<T>> shards = datasources.parallelStream();
        return shards.map(function);
    }

    /**
     * Groups the keys by the shard selected for them and applies the function to each
     * (shard, keys) pair through a parallel stream.
     *
     * @param collection the keys to dispatch
     * @param biFunction the function to apply to a shard and the keys selected for it
     * @param <K>        the key type
     * @param <E>        the result type
     * @return a stream of the per-shard results
     */
    protected <K, E> Stream<E> map(Collection<K> collection, BiFunction<Datasource<T>, Collection<K>, E> biFunction) {
        Map<Datasource<T>, List<K>> keysByShard = collection.stream().collect(Collectors.groupingBy(this::forKey));
        return keysByShard.entrySet()
            .parallelStream()
            .map(shardKeys -> biFunction.apply(shardKeys.getKey(), shardKeys.getValue()));
    }

    /**
     * Groups the entities by the shard selected for them and hands each (shard, entities) pair
     * to the consumer through a parallel stream.
     *
     * @param collection the entities to dispatch
     * @param biConsumer the action to run with a shard and the entities selected for it
     */
    protected void foreach(Collection<T> collection, BiConsumer<Datasource<T>, Collection<T>> biConsumer) {
        Map<Datasource<T>, List<T>> entitiesByShard = collection.stream().collect(Collectors.groupingBy(this::forEntity));
        entitiesByShard.entrySet()
            .parallelStream()
            .forEach(shardEntities -> biConsumer.accept(shardEntities.getKey(), shardEntities.getValue()));
    }

    /**
     * Selects the shard for an entity by reading its key through the model's key member.
     *
     * @param t the entity
     * @return the shard chosen by the sharding strategy for the entity's key
     */
    protected Datasource<T> forEntity(T t) {
        Object key = model().key().member().getOn(t);
        return shardingStrategy.selectFor(key, datasources);
    }

    /**
     * Selects the shard for a key.
     *
     * @param obj the key
     * @return the shard chosen by the sharding strategy for the key
     */
    protected Datasource<T> forKey(Object obj) {
        return shardingStrategy.selectFor(obj, datasources);
    }

    /**
     * Derives the pagination sent to each individual shard from the caller's pagination.
     *
     * <p>The copy starts at the original start divided by the number of shards (never below
     * {@code 0}) and is limited to the original limit multiplied by the number of shards. A
     * limit of {@code Integer.MAX_VALUE} is kept as is.
     *
     * @param pagination the pagination requested by the caller
     * @return a modified copy of the pagination; the argument itself is not modified here
     */
    protected Pagination<T> pagination(Pagination<T> pagination) {
        int shardCount = this.datasources.size();
        return pagination.copy()
            .startAt(Math.max(0, pagination.start() / shardCount))
            // Multiply in long and clamp: a large but finite limit would otherwise overflow
            // int and turn into a negative or far too small per-shard limit.
            .limitTo(pagination.limit() < Integer.MAX_VALUE
                ? (int) Math.min((long) pagination.limit() * shardCount, (long) Integer.MAX_VALUE)
                : Integer.MAX_VALUE);
    }

    /**
     * Entry point for configuring a sharded datasource.
     *
     * @return a new {@link ShardedDatasourceConfiguration}
     */
    public static ShardedDatasourceConfiguration configure() {
        ShardedDatasourceConfiguration configuration = new ShardedDatasourceConfiguration();
        return configuration;
    }
}
