package com.taosdata.jdbc.ws.tmq;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.taosdata.jdbc.TSDBError;
import com.taosdata.jdbc.TSDBErrorNumbers;
import com.taosdata.jdbc.common.Consumer;
import com.taosdata.jdbc.enums.WSFunction;
import com.taosdata.jdbc.tmq.*;
import com.taosdata.jdbc.ws.FutureResponse;
import com.taosdata.jdbc.ws.InFlightRequest;
import com.taosdata.jdbc.ws.Transport;
import com.taosdata.jdbc.ws.entity.Code;
import com.taosdata.jdbc.ws.entity.FetchBlockResp;
import com.taosdata.jdbc.ws.entity.Request;
import com.taosdata.jdbc.ws.entity.Response;
import com.taosdata.jdbc.ws.tmq.entity.*;

import java.nio.ByteOrder;
import java.sql.SQLException;
import java.time.Duration;
import java.util.*;
import java.util.function.Function;

import static com.taosdata.jdbc.TSDBErrorNumbers.ERROR_TMQ_VGROUP_NOT_FOUND;

/**
 * {@link Consumer} implementation that exchanges TMQ requests and responses with the
 * server through a {@link Transport} opened with {@link WSFunction#TMQ}.
 *
 * <p>{@link #create(Properties)} must be called before any other operation: it builds the
 * request factory, the consumer parameters and the transport that every other method uses.
 *
 * @param <V> type produced by the {@link Deserializer} handed to {@link #poll(Duration, Deserializer)}
 */
public class WSConsumer<V> implements Consumer<V> {
    private Transport transport;
    private ConsumerParam param;
    private TMQRequestFactory factory;
    /** Message id of the last polled message that has not been committed yet; 0 means nothing to commit. */
    private long offset = 0L;

    /**
     * Builds the transport, installs the text and binary message handlers and checks the connection.
     *
     * @param properties consumer configuration, parsed by {@link ConsumerParam}
     * @throws SQLException if the parameters are rejected or the connection check fails
     */
    @Override
    public void create(Properties properties) throws SQLException {
        factory = new TMQRequestFactory();
        param = new ConsumerParam(properties);
        InFlightRequest inFlightRequest = new InFlightRequest(param.getConnectionParam().getRequestTimeout()
                , param.getConnectionParam().getMaxRequest());
        transport = new Transport(WSFunction.TMQ, param.getConnectionParam(), inFlightRequest);

        // Text frames carry JSON responses; the "action" field selects the response class.
        transport.setTextMessageHandler(message -> {
            JSONObject jsonObject = JSON.parseObject(message);
            if (null == jsonObject) {
                // Nothing parseable in the frame: drop it instead of failing on the handler thread.
                return;
            }
            ConsumerAction action = ConsumerAction.of(jsonObject.getString("action"));
            if (null == action) {
                // Unrecognised or missing action: there is no response class to map it to.
                return;
            }
            Response response = jsonObject.toJavaObject(action.getResponseClazz());
            FutureResponse remove = inFlightRequest.remove(response.getAction(), response.getReqId());
            if (null != remove) {
                remove.getFuture().complete(response);
            }
        });
        // Binary frames are matched against pending FETCH_BLOCK requests.
        transport.setBinaryMessageHandler(byteBuffer -> {
            byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
            // request_id: little-endian long stored at byte offset 8
            byteBuffer.position(8);
            long id = byteBuffer.getLong();
            // Leave the buffer positioned at byte 24 for FetchBlockResp
            // (presumably where the block payload starts; confirm against the wire format).
            byteBuffer.position(24);
            FutureResponse remove = inFlightRequest.remove(ConsumerAction.FETCH_BLOCK.getAction(), id);
            if (null != remove) {
                FetchBlockResp fetchBlockResp = new FetchBlockResp(id, byteBuffer);
                remove.getFuture().complete(fetchBlockResp);
            }
        });

        Transport.checkConnection(transport, param.getConnectionParam().getConnectTimeout());
    }

    /**
     * Sends a subscribe request for the given topics using the configured credentials,
     * group, client and commit settings.
     *
     * @param topics topics to subscribe to
     * @throws SQLException if the request fails or the server answers with a non-success code
     */
    @Override
    public void subscribe(Collection<String> topics) throws SQLException {
        Request request = factory.generateSubscribe(param.getConnectionParam().getUser()
                , param.getConnectionParam().getPassword()
                , param.getConnectionParam().getDatabase()
                , param.getGroupId()
                , param.getClientId()
                , param.getOffsetRest()
                , topics.toArray(new String[0])
                , String.valueOf(param.isAutoCommit())
                , param.getAutoCommitInterval()
                , param.getSnapshotEnable()
                , param.getMsgWithTableName()
        );
        SubscribeResp response = (SubscribeResp) transport.send(request);
        if (Code.SUCCESS.getCode() != response.getCode()) {
            throw new SQLException("subscribe topic error, code: 0x(" + Integer.toHexString(response.getCode())
                    + "), message: " + response.getMessage());
        }
    }

    /**
     * Sends an unsubscribe request.
     *
     * @throws SQLException if the request fails or the server answers with a non-success code
     */
    @Override
    public void unsubscribe() throws SQLException {
        Request request = factory.generateUnsubscribe();
        UnsubscribeResp response = (UnsubscribeResp) transport.send(request);
        if (Code.SUCCESS.getCode() != response.getCode()) {
            throw new SQLException("unsubscribe topic error, code: 0x(" + Integer.toHexString(response.getCode())
                    + "), message: " + response.getMessage() + ", timing: " + response.getTiming());
        }
    }

    /**
     * Not supported by this consumer.
     *
     * @throws SQLException always, with {@link TSDBErrorNumbers#ERROR_UNSUPPORTED_METHOD}
     */
    @Override
    public Set<String> subscription() throws SQLException {
        throw TSDBError.createSQLException(TSDBErrorNumbers.ERROR_UNSUPPORTED_METHOD);
    }

    /**
     * Polls once and converts every row of the returned message into a record.
     *
     * <p>The polled message id is remembered as the pending offset; when auto commit is
     * enabled it is committed before this method returns.
     *
     * @param timeout      poll timeout, sent to the server in milliseconds
     * @param deserializer converts the current result-set row into a {@code V}
     * @return the records of the polled message, or an empty record set when the
     * response reports no message
     * @throws SQLException if polling, fetching, deserializing or committing fails
     */
    @Override
    public ConsumerRecords<V> poll(Duration timeout, Deserializer<V> deserializer) throws SQLException {
        Request request = factory.generatePoll(timeout.toMillis());
        PollResp pollResp = (PollResp) transport.send(request);
        if (Code.SUCCESS.getCode() != pollResp.getCode()) {
            throw new SQLException("consumer poll error, code: 0x(" + Integer.toHexString(pollResp.getCode()) + "), message: " + pollResp.getMessage());
        }
        if (!pollResp.isHaveMessage()) {
            return ConsumerRecords.emptyRecord();
        }

        offset = pollResp.getMessageId();
        ConsumerRecords<V> records = new ConsumerRecords<>(pollResp.getMessageId());

        // Topic, database and vgroup are read once from the poll response and reused for every row.
        String topic = pollResp.getTopic();
        String dbName = pollResp.getDatabase();
        int vGroupId = pollResp.getVgroupId();

        try (WSConsumerResultSet rs = new WSConsumerResultSet(transport, factory, pollResp.getMessageId(), pollResp.getDatabase())) {
            while (rs.next()) {
                TopicPartition tp = new TopicPartition(topic, dbName, vGroupId);
                V v = deserializer.deserialize(rs, topic, dbName);
                ConsumerRecord<V> r = new ConsumerRecord<>(topic, dbName, vGroupId, pollResp.getOffset(), v);
                records.put(tp, r);
            }
        }

        if (param.isAutoCommit()) {
            this.commitSync();
        }
        return records;
    }

    /**
     * Commits the pending offset, if there is one, and clears it on success.
     *
     * @throws SQLException if the request fails or the server answers with a non-success code
     */
    @Override
    public synchronized void commitSync() throws SQLException {
        if (0 == offset) {
            // Nothing has been polled since the last successful commit.
            return;
        }
        CommitResp commitResp = (CommitResp) transport.send(factory.generateCommit(offset));
        if (Code.SUCCESS.getCode() != commitResp.getCode()) {
            throw new SQLException("consumer commit error. code: 0x(" + Integer.toHexString(commitResp.getCode()) + "), message: " + commitResp.getMessage());
        }
        offset = 0;
    }

    /**
     * Closes the underlying transport. Does nothing when no transport has been created,
     * so it is safe to call after a failed or missing {@link #create(Properties)}.
     *
     * @throws SQLException if closing the transport fails
     */
    @Override
    public void close() throws SQLException {
        if (null != transport) {
            transport.close();
        }
    }

    /**
     * Intentionally a no-op: nothing is committed and {@code callback} is never invoked.
     *
     * @param callback ignored
     */
    @Override
    public void commitAsync(OffsetCommitCallback<V> callback) {
        // nothing to do
    }

    /**
     * Asks the server to move the given partition to {@code offset}.
     *
     * @param partition topic and vgroup to reposition
     * @param offset    target offset
     * @throws SQLException if the request fails or the server answers with a non-success code
     */
    @Override
    public void seek(TopicPartition partition, long offset) throws SQLException {
        Request request = factory.generateSeek(partition.getTopic(), partition.getVGroupId(), offset);
        SeekResp resp = (SeekResp) transport.send(request);
        if (Code.SUCCESS.getCode() != resp.getCode()) {
            throw new SQLException("consumer seek error, code: 0x(" + Integer.toHexString(resp.getCode())
                    + "), message: " + resp.getMessage() + ", timing: " + resp.getTiming());
        }
    }

    /**
     * Returns the current offset of the assignment whose vgroup id matches {@code partition}.
     *
     * @param partition topic and vgroup to look up
     * @return current offset of the matching assignment
     * @throws SQLException if the assignment request fails
     * @throws IllegalStateException if the topic has no assignment for that vgroup
     * ({@code ERROR_TMQ_VGROUP_NOT_FOUND})
     */
    @Override
    public long position(TopicPartition partition) throws SQLException {
        return Arrays.stream(getAssignment(partition.getTopic())).
                filter(a -> a.getVgId() == partition.getVGroupId())
                .findFirst()
                .orElseThrow(() -> TSDBError.createIllegalStateException(ERROR_TMQ_VGROUP_NOT_FOUND))
                .getCurrentOffset();
    }

    /**
     * Returns the current offset of every assignment of {@code topic}.
     *
     * @param topic topic to query
     * @return map from partition to its current offset
     * @throws SQLException if the assignment request fails
     */
    @Override
    public Map<TopicPartition, Long> position(String topic) throws SQLException {
        return offsetsByPartition(topic, Assignment::getCurrentOffset);
    }

    /**
     * Returns the begin offset of every assignment of {@code topic}.
     *
     * @param topic topic to query
     * @return map from partition to its begin offset
     * @throws SQLException if the assignment request fails
     */
    @Override
    public Map<TopicPartition, Long> beginningOffsets(String topic) throws SQLException {
        return offsetsByPartition(topic, Assignment::getBegin);
    }

    /**
     * Returns the end offset of every assignment of {@code topic}.
     *
     * @param topic topic to query
     * @return map from partition to its end offset
     * @throws SQLException if the assignment request fails
     */
    @Override
    public Map<TopicPartition, Long> endOffsets(String topic) throws SQLException {
        return offsetsByPartition(topic, Assignment::getEnd);
    }

    /** Fetches the assignments of {@code topic} and maps each partition to the value chosen by {@code extractor}. */
    private Map<TopicPartition, Long> offsetsByPartition(String topic, Function<Assignment, Long> extractor) throws SQLException {
        Map<TopicPartition, Long> offsets = new HashMap<>();
        for (Assignment assignment : getAssignment(topic)) {
            offsets.put(new TopicPartition(topic, assignment.getVgId()), extractor.apply(assignment));
        }
        return offsets;
    }

    /** Requests the assignments of {@code topic} and fails with an {@link SQLException} on a non-success code. */
    private Assignment[] getAssignment(String topic) throws SQLException {
        Request request = factory.generateAssignment(topic);
        AssignmentResp resp = (AssignmentResp) transport.send(request);
        if (Code.SUCCESS.getCode() != resp.getCode()) {
            throw new SQLException("consumer assignment error, code: 0x(" + Integer.toHexString(resp.getCode())
                    + "), message: " + resp.getMessage() + ", timing: " + resp.getTiming());
        }
        return resp.getAssignment();
    }
}
